Repository: mit-han-lab/deepcompressor
Branch: main
Commit: 69f3473f5e1c
Files: 229
Total size: 3.7 MB
Directory structure:
gitextract_0u4zluxv/
├── .gitignore
├── LICENSE
├── README.md
├── assets/
│ ├── diffusion/
│ │ └── .gitkeep
│ └── llm/
│ └── .gitkeep
├── deepcompressor/
│ ├── __init__.py
│ ├── app/
│ │ ├── __init__.py
│ │ ├── diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── cache/
│ │ │ │ ├── __init__.py
│ │ │ │ └── config.py
│ │ │ ├── config.py
│ │ │ ├── dataset/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── base.py
│ │ │ │ ├── calib.py
│ │ │ │ ├── collect/
│ │ │ │ │ ├── calib.py
│ │ │ │ │ └── utils.py
│ │ │ │ └── data/
│ │ │ │ ├── COCO/
│ │ │ │ │ ├── COCO.py
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── DCI/
│ │ │ │ │ ├── DCI.py
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── MJHQ/
│ │ │ │ │ ├── MJHQ.py
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── dump.py
│ │ │ ├── eval/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── metrics/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── fid.py
│ │ │ │ ├── image_reward.py
│ │ │ │ ├── multimodal.py
│ │ │ │ ├── run.py
│ │ │ │ └── similarity.py
│ │ │ ├── nn/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── attention.py
│ │ │ │ ├── patch.py
│ │ │ │ └── struct.py
│ │ │ ├── pipeline/
│ │ │ │ ├── __init__.py
│ │ │ │ └── config.py
│ │ │ ├── ptq.py
│ │ │ ├── quant/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── activation.py
│ │ │ │ ├── config.py
│ │ │ │ ├── quantizer/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── config.py
│ │ │ │ │ └── quantizer.py
│ │ │ │ ├── rotate.py
│ │ │ │ ├── smooth.py
│ │ │ │ ├── utils.py
│ │ │ │ └── weight.py
│ │ │ └── utils.py
│ │ └── llm/
│ │ ├── __init__.py
│ │ ├── cache/
│ │ │ ├── __init__.py
│ │ │ └── config.py
│ │ ├── config.py
│ │ ├── eval/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── config.py
│ │ │ ├── custom.py
│ │ │ ├── lm_eval.py
│ │ │ └── longbench/
│ │ │ ├── __init__.py
│ │ │ ├── eval.py
│ │ │ ├── metrics.py
│ │ │ └── task2prompt.json
│ │ ├── model/
│ │ │ ├── __init__.py
│ │ │ └── config.py
│ │ ├── nn/
│ │ │ ├── __init__.py
│ │ │ ├── patch.py
│ │ │ └── struct.py
│ │ ├── ptq.py
│ │ └── quant/
│ │ ├── __init__.py
│ │ ├── activation.py
│ │ ├── config.py
│ │ ├── dataset.py
│ │ ├── quantizer/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── quantizer.py
│ │ ├── reorder.py
│ │ ├── rotate.py
│ │ ├── smooth.py
│ │ ├── utils.py
│ │ └── weight.py
│ ├── backend/
│ │ ├── __init__.py
│ │ ├── nunchaku/
│ │ │ ├── __init__.py
│ │ │ ├── convert.py
│ │ │ ├── convert_lora.py
│ │ │ └── utils.py
│ │ ├── qserve/
│ │ │ ├── __init__.py
│ │ │ ├── convert.py
│ │ │ └── utils.py
│ │ ├── tinychat/
│ │ │ ├── __init__.py
│ │ │ ├── convert.py
│ │ │ ├── csrc/
│ │ │ │ ├── load.py
│ │ │ │ ├── pybind.cpp
│ │ │ │ ├── quantization/
│ │ │ │ │ ├── dequantize.cuh
│ │ │ │ │ ├── gemm/
│ │ │ │ │ │ ├── gemm_cuda.cu
│ │ │ │ │ │ ├── gemm_cuda.h
│ │ │ │ │ │ └── semaphore.h
│ │ │ │ │ └── gemv/
│ │ │ │ │ ├── gemv_cuda.cu
│ │ │ │ │ └── gemv_cuda.h
│ │ │ │ └── utils.cuh
│ │ │ ├── linear.py
│ │ │ └── utils.py
│ │ └── utils.py
│ ├── calib/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── lowrank.py
│ │ │ ├── range.py
│ │ │ ├── reorder.py
│ │ │ ├── rotation.py
│ │ │ ├── search.py
│ │ │ └── smooth.py
│ │ ├── lowrank.py
│ │ ├── metric.py
│ │ ├── range.py
│ │ ├── reorder.py
│ │ ├── rotate.py
│ │ ├── search.py
│ │ └── smooth.py
│ ├── csrc/
│ │ ├── load.py
│ │ ├── pybind.cpp
│ │ └── quantize/
│ │ ├── quantize.cu
│ │ └── quantize.h
│ ├── data/
│ │ ├── __init__.py
│ │ ├── cache.py
│ │ ├── codebook.py
│ │ ├── common.py
│ │ ├── dtype.py
│ │ ├── range.py
│ │ ├── scale.py
│ │ ├── tensor.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ ├── dtype.py
│ │ │ ├── reshape.py
│ │ │ ├── scale.py
│ │ │ └── shape.py
│ │ └── zero.py
│ ├── dataset/
│ │ ├── __init__.py
│ │ ├── action.py
│ │ ├── cache.py
│ │ └── config.py
│ ├── nn/
│ │ ├── __init__.py
│ │ ├── patch/
│ │ │ ├── __init__.py
│ │ │ ├── conv.py
│ │ │ ├── linear.py
│ │ │ ├── lowrank.py
│ │ │ └── sdpa.py
│ │ └── struct/
│ │ ├── __init__.py
│ │ ├── attn.py
│ │ └── base.py
│ ├── quantizer/
│ │ ├── __init__.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── kernel.py
│ │ │ └── lowrank.py
│ │ ├── impl/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── info.py
│ │ │ ├── scale.py
│ │ │ ├── simple.py
│ │ │ └── ste.py
│ │ ├── kernel/
│ │ │ ├── __init__.py
│ │ │ ├── gptq.py
│ │ │ └── rtn.py
│ │ └── processor.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── config/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── model.py
│ │ │ ├── output.py
│ │ │ └── path.py
│ │ ├── dataclass.py
│ │ ├── hooks/
│ │ │ ├── __init__.py
│ │ │ ├── branch.py
│ │ │ ├── hook.py
│ │ │ ├── packager.py
│ │ │ └── processor.py
│ │ ├── math/
│ │ │ ├── __init__.py
│ │ │ ├── functional.py
│ │ │ └── hadamard.py
│ │ ├── patch.py
│ │ └── tools/
│ │ ├── __init__.py
│ │ ├── logging.py
│ │ └── sys.py
│ └── version.py
├── environment.yml
├── examples/
│ ├── diffusion/
│ │ ├── .gitignore
│ │ ├── README.md
│ │ ├── configs/
│ │ │ ├── __default__.yaml
│ │ │ ├── collect/
│ │ │ │ └── qdiff.yaml
│ │ │ ├── lora/
│ │ │ │ ├── __default__.yaml
│ │ │ │ └── flux.1-dev/
│ │ │ │ ├── anime.yaml
│ │ │ │ ├── ghibsky.yaml
│ │ │ │ ├── realism.yaml
│ │ │ │ ├── sketch.yaml
│ │ │ │ └── yarn.yaml
│ │ │ ├── model/
│ │ │ │ ├── flux.1-dev.yaml
│ │ │ │ ├── flux.1-schnell.yaml
│ │ │ │ ├── pixart-sigma.yaml
│ │ │ │ └── sana-1.6b.yaml
│ │ │ ├── svdquant/
│ │ │ │ ├── __default__.yaml
│ │ │ │ ├── fast.yaml
│ │ │ │ ├── gptq.yaml
│ │ │ │ ├── int4.yaml
│ │ │ │ └── nvfp4.yaml
│ │ │ └── text/
│ │ │ ├── __default__.yaml
│ │ │ └── awq.yaml
│ │ ├── prompts/
│ │ │ ├── lora/
│ │ │ │ ├── anime.yaml
│ │ │ │ ├── ghibsky.yaml
│ │ │ │ ├── realism.yaml
│ │ │ │ ├── sketch.yaml
│ │ │ │ └── yarn.yaml
│ │ │ └── qdiff.yaml
│ │ └── scripts/
│ │ └── svdquant.sh
│ └── llm/
│ ├── .gitignore
│ ├── README.md
│ ├── configs/
│ │ ├── __default__.yaml
│ │ ├── awq.yaml
│ │ ├── gptq.yaml
│ │ ├── ooo.yaml
│ │ ├── qoq-g128.yaml
│ │ ├── qoq-gchn.yaml
│ │ ├── smoothquant-dynamic.yaml
│ │ └── smoothquant-static.yaml
│ └── scripts/
│ ├── awq.sh
│ ├── gptq.sh
│ ├── qoq.sh
│ └── smoothquant.sh
└── pyproject.toml
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VS Code
.vscode/
!.vscode/settings.json
.DS_Store
*.log
*.pt
.tmp/
runs
exps
runs/
exps/
wandb
wandb/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [2024] Yujun Lin, Muyang Li, Zhekai Zhang, Haotian Tang, Shang Yang, Song Han
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
Model Compression Toolbox for Large Language Models and Diffusion Models
## News
- [2025/02] 🎉 [**QServe**](https://arxiv.org/abs/2405.04532) has been accepted to MLSys 2025!
- [2025/01] 🎉 [**SVDQuant**](https://arxiv.org/abs/2411.05007) has been accepted to ICLR 2025 (Spotlight)!
- [2024/12] 🎉 [**QServe**](https://github.com/mit-han-lab/qserve) has been integrated into NVIDIA [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama)!
- [2024/11] 🔥 Our latest **W4A4** diffusion model quantization work [**SVDQuant**](https://arxiv.org/abs/2411.05007) algorithm and [**Nunchaku**](https://github.com/mit-han-lab/nunchaku) system is publicly released! Check our [paper](http://arxiv.org/abs/2411.05007)!
- [2024/05] 🔥 Our latest **W4A8KV4** LLM quantization work **QoQ** algorithm and **QServe** system is publicly released! **QoQ** is short for *quattuor-octō-quattuor* which is 4-8-4 in Latin. Check our [paper](https://arxiv.org/abs/2405.04532)!
## Key Features
***DeepCompressor*** is an open source model compression toolbox for large language models and diffusion models based on PyTorch. DeepCompressor currently supports fake quantization with any integer and floating-point data type within 8 bits, e.g., INT8, INT4 and FP4_E2M1. Here are examples that implement the following algorithms.
+ [Post-training quantization for large language models](/examples/llm/):
+ Weight-only Quantization
+ [AWQ (W4A16)](/examples/llm/configs/awq.yaml)
+ [GPTQ (W4A16)](/examples/llm/configs/gptq.yaml)
+ Weight-Activation Quantization
+ [SmoothQuant (W8A8)](/examples/llm/configs/smoothquant-static.yaml)
+ Weight-Activation and KV-Cache Quantization
+ [QoQ (W4A8KV4)](/examples/llm/)
+ [Post-training quantization for diffusion models](/examples/diffusion/):
+ Weight-Activation Quantization
+ [SVDQuant (W4A4)](/examples/diffusion/)
DeepCompressor also contains examples that integrate with other inference libraries.
+ [Deploy weight-only quantized LLMs with TinyChat](/examples/llm/)
+ [Deploy quantized LLMs with QServe](/examples/llm/)
+ [Deploy quantized diffusion models with Nunchaku](/examples/diffusion/)
## Installation
### Install from Source
1. Clone this repository and navigate to deepcompressor folder
```
git clone https://github.com/mit-han-lab/deepcompressor
cd deepcompressor
```
2. Install Package
```
conda env create -f environment.yml
poetry install
```
## Highlights
### SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
[[Website](https://hanlab.mit.edu/projects/svdquant)][[Paper](http://arxiv.org/abs/2411.05007)][[Nunchaku Inference System](https://github.com/mit-han-lab/nunchaku)]
Diffusion models have been proven highly effective at generating high-quality images. However, as these models grow larger, they require significantly more memory and suffer from higher latency, posing substantial challenges for deployment. In this work, we aim to accelerate diffusion models by quantizing their weights and activations to 4 bits. At such an aggressive level, both weights and activations are highly sensitive, where conventional post-training quantization methods for large language models like smoothing become insufficient. To overcome this limitation, we propose **SVDQuant**, a new 4-bit quantization paradigm. Different from smoothing which redistributes outliers between weights and activations, our approach absorbs these outliers using a low-rank branch. We first consolidate the outliers by shifting them from activations to weights, then employ a high-precision low-rank branch to take in the weight outliers with Singular Value Decomposition (SVD). This process eases the quantization on both sides. However, naïvely running the low-rank branch independently incurs significant overhead due to extra data movement of activations, negating the quantization speedup. To address this, we co-design an inference engine **Nunchaku** that fuses the kernels of the low-rank branch into those of the low-bit branch to cut off redundant memory access. It can also seamlessly support off-the-shelf low-rank adapters (LoRAs) without the need for re-quantization. Extensive experiments on SDXL, PixArt-Σ, and FLUX.1 validate the effectiveness of SVDQuant in preserving image quality. We reduce the memory usage for the 12B FLUX.1 models by 3.5×, achieving 3.0× speedup over the 4-bit weight-only quantized baseline on the 16GB laptop 4090 GPU, paving the way for more interactive applications on PCs.


#### Quality Evaluation
Below is the quality and similarity evaluated with 5000 samples from MJHQ-30K dataset. IR means ImageReward. Our 4-bit results outperform other 4-bit baselines, effectively preserving the visual quality of 16-bit models.
| Model | Precision | Method | FID ($\downarrow$) | IR ($\uparrow$) | LPIPS ($\downarrow$) | PSNR ($\uparrow$) |
|----------------------------|-----------|-----------|--------------------|-----------------|----------------------|-------------------|
| FLUX.1-dev (50 Steps) | BF16 | -- | 20.3 | 0.953 | -- | -- |
| | W4A16 | NF4 | 20.6 | 0.910 | 0.272 | 19.5 |
| | INT W4A4 | | 20.2 | 0.908 | 0.322 | 18.5 |
| | INT W4A4 | SVDQuant | 19.9 | 0.935 | 0.223 | 21.0 |
| | NVFP4 | | 20.3 | 0.961 | 0.345 | 16.3 |
| | NVFP4 | SVDQuant | 20.3 | 0.945 | 0.205 | 21.5 |
| FLUX.1-schnell (4 Steps) | BF16 | -- | 19.2 | 0.938 | -- | -- |
| | W4A16 | NF4 | 18.9 | 0.943 | 0.257 | 18.2 |
| | INT W4A4 | | 18.1 | 0.962 | 0.345 | 16.3 |
| | INT W4A4 | SVDQuant | 18.3 | 0.951 | 0.257 | 18.3 |
| | NVFP4 | | 19.0 | 0.952 | 0.276 | 17.6 |
| | NVFP4 | SVDQuant | 18.9 | 0.966 | 0.228 | 19.0 |
| SANA-1.6b (20 Steps) | BF16 | -- | 20.6 | 0.952 | -- | -- |
| | INT W4A4 | | 20.5 | 0.894 | 0.339 | 15.3 |
| | INT W4A4 | SVDQuant | 19.3 | 0.935 | 0.220 | 17.8 |
| | NVFP4 | | 19.7 | 0.929 | 0.236 | 17.4 |
| | NVFP4 | SVDQuant | 20.2 | 0.941 | 0.176 | 19.0 |
| PixArt-Sigma (20 Steps) | FP16 | -- | 16.6 | 0.944 | -- | -- |
| | INT W4A8 | ViDiT-Q | 37.3 | 0.573 | 0.611 | 12.0 |
| | INT W4A4 | SVDQuant | 19.2 | 0.878 | 0.323 | 17.6 |
| | NVFP4 | | 31.8 | 0.660 | 0.517 | 14.8 |
| | NVFP4 | SVDQuant | 16.6 | 0.940 | 0.271 | 18.5 |
### QServe: W4A8KV4 Quantization for Efficient LLM Serving
[[Website](https://hanlab.mit.edu/projects/qserve)][[Paper](https://arxiv.org/abs/2405.04532)][[QoQ Algorithm Code](/examples/llm)][[QServe GPU System](https://github.com/mit-han-lab/qserve)]
Quantization can accelerate large language model (LLM) inference. Going beyond INT8 quantization, the research community is actively exploring even lower precision, such as INT4. Nonetheless, state-of-the-art INT4 quantization techniques only accelerate low-batch, edge LLM inference, failing to deliver performance gains in large-batch, cloud-based LLM serving. We uncover a critical issue: existing INT4 quantization methods suffer from significant runtime overhead (20-90%) when **dequantizing either weights or partial sums** on GPUs. To address this challenge, we introduce **QoQ**, a W4A8KV4 quantization algorithm with 4-bit weight, 8-bit activation, and 4-bit KV cache. QoQ stands for **quattuor-octo-quattuor**, which represents 4-8-4 in Latin. QoQ is implemented by the **QServe** inference library that achieves measured speedup. The key insight driving QServe is that the efficiency of LLM serving on GPUs is critically influenced by **operations on low-throughput CUDA cores**. Building upon this insight, in QoQ algorithm, we introduce progressive quantization that can allow low dequantization overhead in W4A8 GEMM. Additionally, we develop SmoothAttention to effectively mitigate the accuracy degradation incurred by 4-bit KV quantization. In the QServe system, we perform compute-aware weight reordering and take advantage of register-level parallelism to reduce dequantization latency. We also make fused attention memory-bound, harnessing the performance gain brought by KV4 quantization. As a result, QServe improves the maximum achievable serving throughput of Llama-3-8B by **1.2×** on A100, **1.4×** on L40S; and Qwen1.5-72B by **2.4×** on A100, **3.5×** on L40S, compared to TensorRT-LLM.


#### Perplexity Evaluation
Below is the WikiText2 perplexity evaluated with 2048 sequence length. The lower is the better.
| Methods | Precision | Llama-3.1 70B | Llama-3.1 8B | Llama-3 70B | Llama-3 8B | Llama-2 7B | Llama-2 13B | Llama-2 70B | Llama 7B | Llama 13B | Llama 30B | Mistral 7B | Yi 34B |
|-------------|--------------|---------------|--------------|-------------| ------------|------------|-------------|-------------|----------|-----------|-----------|------------|--------|
| FP16 | | 2.81 | 6.24 | 2.85 | 6.14 | 5.47 | 4.88 | 3.32 | 5.68 | 5.09 | 4.10 | 5.25 | 4.60 |
| SmoothQuant | W8A8 | 3.23 | 6.38 | 3.14 | 6.28 | 5.54 | 4.95 | 3.36 | 5.73 | 5.13 | 4.23 | 5.29 | 4.69 |
| GPTQ-R | W4A16 g128 | 3.46 | 6.64 | 3.42 | 6.56 | 5.63 | 4.99 | 3.43 | 5.83 | 5.20 | 4.22 | 5.39 | 4.68 |
| AWQ | W4A16 g128 | 3.22 | 6.60 | 3.20 | 6.54 | 5.60 | 4.97 | 3.41 | 5.78 | 5.19 | 4.21 | 5.37 | 4.67 |
| QuaRot | W4A4 | 5.97 | 8.32 | 6.75 | 8.33 | 6.19 | 5.45 | 3.83 | 6.34 | 5.58 | 4.64 | 5.77 | - |
| SpinQuant | W4A4 | 4.80 | 7.42 | 6.27 | 7.37 | 5.96 | 5.24 | 3.71 | 6.14 | 5.39 | 4.56 | - | - |
| Atom | W4A4 g128 | - | - | 4.33 | 7.78 | 6.12 | 5.31 | 3.73 | 6.25 | 5.52 | 4.61 | 5.76 | 4.97 |
| QoQ | W4A8KV4 | 3.68 | 6.87 | 3.65 | 6.81 | 5.75 | 5.11 | 3.50 | 5.92 | 5.27 | 4.31 | 5.44 | 4.73 |
| QoQ | W4A8KV4 g128 | 3.51 | 6.77 | 3.50 | 6.70 | 5.67 | 5.06 | 3.46 | 5.88 | 5.23 | 4.27 | 5.41 | 4.73 |
\* SmoothQuant is evaluated with per-tensor static KV cache quantization.
#### Efficiency Benchmarks
When serving the large language models Llama-3-8B and Qwen1.5-72B on L40S and A100 GPUs, QServe demonstrates superior performance, achieving **1.2x-1.4x higher throughput** compared to the leading industry solution, TensorRT-LLM, for Llama-3-8B, and a **2.4x-3.5x higher throughput** for Qwen1.5-72B.
See more about benchmarking setting in [QServe GPU Inference System](https://github.com/mit-han-lab/qserve).
| L40S (48G) | Llama-3-8B | Llama-2-7B | Mistral-7B | Llama-2-13B | Llama-30B | Yi-34B | Llama-2-70B | Qwen-1.5-72B |
|----------------------|------------|------------|------------|-------------|-----------|-----------|-------------|--------------|
| TRT-LLM-FP16 | 1326 | 444 | 1566 | 92 | OOM | OOM | OOM | OOM |
| TRT-LLM-W4A16 | 1431 | 681 | 1457 | 368 | 148 | 313 | 119 | 17 |
| TRT-LLM-W8A8 | 2634 | 1271 | 2569 | 440 | 123 | 364 | OOM | OOM |
| Atom-W4A4 | -- | 2120 | -- | -- | -- | -- | -- | -- |
| QuaRot-W4A4 | -- | 805 | -- | 413 | 133 | -- | -- | 15 |
| QServe-W4A8KV4 | **3656** | **2394** | **3774** | **1327** | **504** | **869** | **286** | **59** |
| Throughput Increase* | **1.39x** | **1.13x** | **1.47x** | **3.02x** | **3.41x** | **2.39x** | **2.40x** | **3.47x** |
| A100 (80G) | Llama-3-8B | Llama-2-7B | Mistral-7B | Llama-2-13B | Llama-30B | Yi-34B | Llama-2-70B | Qwen-1.5-72B |
|----------------------|------------| -----------|------------|-------------|-----------|-----------|-------------|--------------|
| TRT-LLM-FP16 | 2503 | 1549 | 2371 | 488 | 80 | 145 | OOM | OOM |
| TRT-LLM-W4A16 | 2370 | 1549 | 2403 | 871 | 352 | 569 | 358 | 143 |
| TRT-LLM-W8A8 | 2396 | 2334 | 2427 | 1277 | 361 | 649 | 235 | 53 |
| Atom-W4A4 | -- | 1160 | -- | -- | -- | -- | -- | -- |
| QuaRot-W4A4 | -- | 1370 | -- | 289 | 267 | -- | -- | 68 |
| QServe-W4A8KV4 | **3005** | **2908** | **2970** | **1741** | **749** | **803** | **419** | **340** |
| Throughput Increase* | **1.20x** | **1.25x** | **1.22x** | **1.36x** | **2.07x** | **1.23x** | **1.17x** | **2.38x** |
The absolute token generation throughputs of QServe and baseline systems (Unit: tokens/second. `--` means unsupported). All experiments were conducted under the same device memory budget. Throughput increase of QServe is calculated with regard to the best baseline in each column.
## Reference
If you find `deepcompressor` useful or relevant to your research, please kindly cite our paper:
```bibtex
@article{lin2024qserve,
title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and Han, Song},
journal={arXiv preprint arXiv:2405.04532},
year={2024}
}
@article{
li2024svdquant,
title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
journal={arXiv preprint arXiv:2411.05007},
year={2024}
}
```
## Related Projects
The following projects are highly related to QServe. Our group has developed full-stack application-algorithm-system-hardware support for efficient large models, receiving **9k+ GitHub stars** and **over 1M Huggingface community downloads**.
You are also welcome to check out [MIT HAN Lab](https://hanlab.mit.edu) for other exciting projects on **Efficient Generative AI**!
- [**System**] [QServe: W4A8KV4 Quantization for Efficient LLM Serving](https://github.com/mit-han-lab/qserve)
- [**System**] [TinyChat: Efficient and Lightweight Chatbot with AWQ](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat)
- [**Application**] [VILA: On Pretraining of Visual-Language Models](https://github.com/Efficient-Large-Model/VILA)
- [**Algorithm**] [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://github.com/mit-han-lab/llm-awq)
- [**Algorithm**] [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://github.com/mit-han-lab/smoothquant)
- [**Algorithm**] [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://github.com/mit-han-lab/distrifuser)
- [**Hardware**] [SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning](https://arxiv.org/abs/2012.09852)
## Acknowledgments
DeepCompressor is inspired by many open-source libraries, including (but not limited to) [GPTQ](https://arxiv.org/abs/2210.17323), [QuaRot](https://arxiv.org/abs/2404.00456) and [Atom](https://arxiv.org/abs/2310.19102).
================================================
FILE: assets/diffusion/.gitkeep
================================================
================================================
FILE: assets/llm/.gitkeep
================================================
================================================
FILE: deepcompressor/__init__.py
================================================
from .version import __version__ # noqa: F401
================================================
FILE: deepcompressor/app/__init__.py
================================================
================================================
FILE: deepcompressor/app/diffusion/__init__.py
================================================
================================================
FILE: deepcompressor/app/diffusion/cache/__init__.py
================================================
from .config import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig
================================================
FILE: deepcompressor/app/diffusion/cache/config.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model quantization cache configuration."""
import functools
import re
import typing as tp
from dataclasses import dataclass, field
from omniconfig import configclass
from deepcompressor.utils.config.path import BasePathConfig
from ..nn.struct import DiffusionModelStruct
__all__ = ["DiffusionQuantCacheConfig", "DiffusionPtqCacheConfig"]
@dataclass
class DiffusionQuantCacheConfig(BasePathConfig):
    """Denoising diffusion model quantization cache path.

    Args:
        smooth (`str`, *optional*, default=`""`):
            The smoothing scales cache path.
        branch (`str`, *optional*, default=`""`):
            The low-rank branches cache path.
        wgts (`str`, *optional*, default=`""`):
            The weight quantizers state dict cache path.
        acts (`str`, *optional*, default=`""`):
            The activation quantizers state dict cache path.
    """

    smooth: str = ""
    branch: str = ""
    wgts: str = ""
    acts: str = ""

    @staticmethod
    def simplify_path(path: str, key_map: dict[str, set[str]]) -> str:
        """Shorten every ``skip.[...]`` / ``include.[...]`` segment of a cache path.

        Each module-key list inside the brackets is first simplified via
        ``DiffusionModelStruct._simplify_keys`` and every surviving key is then
        abbreviated to the initials of its underscore-separated words.
        """
        replacements: dict[str, str] = {}
        # Find every segment shaped like "skip.[key_a+key_b]" or "include.[...]".
        for match in re.finditer(r"(skip|include)\.\[[a-zA-Z0-9_\+]+\]", path):
            segment = match.group(0)
            # Split "skip.[a+b]" into the prefix ("skip"/"include") and "[a+b]".
            prefix, _, bracketed = segment.partition(".")
            keys = bracketed[1:-1].split("+")
            # Abbreviate each simplified key, e.g. "up_proj" -> "up" initials "up" -> "up"[0]s.
            abbreviated = "+".join(
                "".join(word[0] for word in key.split("_"))
                for key in DiffusionModelStruct._simplify_keys(keys, key_map=key_map)
            )
            replacements[segment] = f"{prefix}.[{abbreviated}]"
        # Apply all substitutions after scanning so iteration is not disturbed.
        for old, new in replacements.items():
            path = path.replace(old, new)
        return path

    def simplify(self, key_map: dict[str, set[str]]) -> tp.Self:
        """Return a copy of this config with every cache path simplified."""
        return self.apply(functools.partial(self.simplify_path, key_map=key_map))
@configclass
@dataclass
class DiffusionPtqCacheConfig:
    """Cache path configuration for diffusion model post-training quantization.

    Args:
        root (`str`):
            Root directory under which quantization caches are stored.
    """

    root: str
    # `init=False`: these are derived after construction (e.g. in
    # DiffusionPtqRunConfig.__post_init__) rather than passed by the caller.
    dirpath: DiffusionQuantCacheConfig = field(init=False)
    path: DiffusionQuantCacheConfig = field(init=False)
================================================
FILE: deepcompressor/app/diffusion/config.py
================================================
# -*- coding: utf-8 -*-
"""Top-level config of post-training quantization for a diffusion model."""
import os
from dataclasses import dataclass, field
import diffusers.training_utils
import omniconfig
import torch
from omniconfig import ConfigParser, configclass
from deepcompressor.app.llm.config import LlmCacheConfig, LlmQuantConfig
from deepcompressor.data.utils import ScaleUtils
from deepcompressor.utils.config.output import OutputConfig
from .cache import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig
from .eval import DiffusionEvalConfig
from .nn.struct import DiffusionModelStruct
from .pipeline import DiffusionPipelineConfig
from .quant import DiffusionQuantConfig
__all__ = [
"DiffusionPtqRunConfig",
"DiffusionPtqCacheConfig",
"DiffusionQuantCacheConfig",
"DiffusionEvalConfig",
"DiffusionPipelineConfig",
"DiffusionQuantConfig",
]
@configclass
@dataclass
class DiffusionPtqRunConfig:
    """Top-level config of post-training quantization for a diffusion model.

    Args:
        cache (`DiffusionPtqCacheConfig`):
            The cache configuration.
        output (`OutputConfig`):
            The output directory configuration.
        pipeline (`DiffusionPipelineConfig`):
            The diffusion pipeline configuration.
        eval (`DiffusionEvalConfig`):
            The evaluation configuration.
        quant (`DiffusionQuantConfig`):
            The post-training quantization configuration.
        text (`LlmQuantConfig`, *optional*, defaults to `None`):
            Quantization configuration for the text encoder (treated as an LLM).
        text_cache (`LlmCacheConfig`, *optional*):
            Cache configuration for the text encoder quantization.
        seed (`int`, *optional*, defaults to `12345`):
            The seed for reproducibility.
        skip_gen (`bool`, *optional*, defaults to `False`):
            Whether to skip generation.
        skip_eval (`bool`, *optional*, defaults to `False`):
            Whether to skip evaluation.
        load_from (`str`, *optional*, defaults to `""`):
            Directory path to load the model checkpoint from.
        save_model (`str`, *optional*, defaults to `""`):
            Directory path to save the model checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the quantization cache on save.
    """

    cache: DiffusionPtqCacheConfig | None
    output: OutputConfig
    pipeline: DiffusionPipelineConfig
    eval: DiffusionEvalConfig
    # Empty argparse prefix exposes the quantization options at the CLI top level.
    quant: DiffusionQuantConfig = field(metadata={omniconfig.ARGPARSE_KWARGS: {"prefix": ""}})
    text: LlmQuantConfig | None = None
    text_cache: LlmCacheConfig = field(default_factory=LlmCacheConfig)
    seed: int = 12345
    skip_gen: bool = False
    skip_eval: bool = False
    load_from: str = ""
    save_model: str = ""
    copy_on_save: bool = False

    def __post_init__(self):
        # region set text encoder quantization scale default dtype
        # The text-encoder quantizer scale dtypes default to the pipeline dtype
        # when left unspecified, for weights / inputs / outputs respectively.
        if self.text is not None and self.text.enabled_wgts:
            self.text.wgts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.wgts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        if self.text is not None and self.text.enabled_ipts:
            self.text.ipts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.ipts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        if self.text is not None and self.text.enabled_opts:
            self.text.opts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.opts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        # endregion
        # Never request more GPUs than are actually visible.
        self.eval.num_gpus = min(torch.cuda.device_count(), self.eval.num_gpus)
        if self.eval.batch_size_per_gpu is None:
            # Derive the per-GPU batch size from the global one, then recompute the
            # global batch size so it is an exact multiple of num_gpus.
            self.eval.batch_size_per_gpu = max(1, self.eval.batch_size // self.eval.num_gpus)
            self.eval.batch_size = self.eval.batch_size_per_gpu * self.eval.num_gpus
        else:
            # Per-GPU batch size was given explicitly; it takes precedence.
            self.eval.batch_size = self.eval.batch_size_per_gpu * self.eval.num_gpus
        # region setup calib dataset path
        # The calibration path is a template with placeholders filled from the
        # pipeline and evaluation settings.
        self.quant.calib.path = self.quant.calib.path.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
            data=self.quant.calib.data,
        )
        if self.quant.calib.path:
            self.quant.calib.path = os.path.abspath(os.path.expanduser(self.quant.calib.path))
        # endregion
        # region setup eval reference root
        self.eval.ref_root = self.eval.ref_root.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
        )
        if self.eval.ref_root:
            self.eval.ref_root = os.path.abspath(os.path.expanduser(self.eval.ref_root))
        # endregion
        # region setup cache directory
        if self.cache is not None:
            if self.quant.enabled_wgts or self.quant.enabled_ipts or self.quant.enabled_opts:
                self.cache.dirpath = self.quant.generate_cache_dirpath(
                    root=self.cache.root, shift=self.pipeline.shift_activations, default_dtype=self.pipeline.dtype
                )
                # Cache files are named after the pipeline inside the cache directory.
                self.cache.path = self.cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
            else:
                # No quantization enabled -> nothing to cache.
                self.cache.dirpath = self.cache.path = None
            if self.text is not None and self.text.is_enabled():
                if not self.text_cache.root:
                    # Default the text-encoder cache next to the diffusion cache.
                    self.text_cache.root = os.path.join(self.cache.root, "diffusion")
                self.text_cache.dirpath = self.text.generate_cache_dirpath(root=self.text_cache.root, seed=self.seed)
                self.text_cache.path = self.text_cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
        # endregion
        # region setup output directory
        if self.output.dirname == "reference":
            # "reference" runs generate ground-truth samples directly into ref_root.
            assert self.eval.ref_root
            self.output.job = f"run-{self.eval.num_samples}"
            self.output.dirpath = self.eval.ref_root
            self.eval.ref_root = ""
            self.eval.gen_root = "{output}"
        else:
            if self.output.dirname == "default":
                self.output.dirname = self.generate_default_dirname()
            # "-" is a placeholder when no calibration directory name applies.
            calib_dirname = self.quant.generate_calib_dirname() or "-"
            self.output.dirpath = os.path.join(
                self.output.root,
                "diffusion",
                self.pipeline.family,
                self.pipeline.name,
                *self.quant.generate_dirnames(default_dtype=self.pipeline.dtype)[:-1],
                calib_dirname,
                self.output.dirname,
            )
        # Tag the job name with chunk info when generating a strided subset of samples.
        if (self.eval.chunk_start > 0 or self.eval.chunk_step > 1) and not self.eval.chunk_only:
            self.output.job += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        # endregion
        # Seed python/numpy/torch for reproducibility.
        diffusers.training_utils.set_seed(self.seed)

    def generate_default_dirname(self) -> str:
        """Build a descriptive default output directory name from the run settings."""
        name = "-shift" if self.pipeline.shift_activations else ""
        if self.quant.is_enabled():
            name += f"-{self.quant.generate_default_dirname()}"
        if self.text is not None and self.text.is_enabled():
            name += f"-text-{self.text.generate_default_dirname()}"
        # Image size segment, e.g. "h1024.w1024".
        size_name = ""
        if self.eval.height:
            size_name += f".h{self.eval.height}"
        if self.eval.width:
            size_name += f".w{self.eval.width}"
        if size_name:
            name += f"-{size_name[1:]}"
        # Sampling segment, e.g. "t50.g7.5" (steps / guidance scale).
        sampling_name = ""
        if self.eval.num_steps is not None:
            sampling_name += f".t{self.eval.num_steps}"
        if self.eval.guidance_scale is not None:
            sampling_name += f".g{self.eval.guidance_scale}"
        if sampling_name:
            name += f"-{sampling_name[1:]}"
        if self.eval.num_samples != -1:
            name += f"-s{self.eval.num_samples}"
        if self.eval.chunk_only:
            name += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        # Every segment above starts with "-"; strip the leading one.
        assert name[0] == "-"
        return name[1:]

    @classmethod
    def get_parser(cls) -> ConfigParser:
        """Get a parser for post-training quantization of a diffusion model.

        Returns:
            `ConfigParser`:
                A parser for post-training quantization of a diffusion model.
        """
        parser = ConfigParser("Diffusion Run configuration")
        # Register the module-key abbreviation map before building the CLI schema.
        DiffusionQuantConfig.set_key_map(DiffusionModelStruct._get_default_key_map())
        parser.add_config(cls)
        return parser
================================================
FILE: deepcompressor/app/diffusion/dataset/__init__.py
================================================
# -*- coding: utf-8 -*-
from .base import DiffusionDataset
from .calib import DiffusionCalibCacheLoader, DiffusionCalibCacheLoaderConfig
================================================
FILE: deepcompressor/app/diffusion/dataset/base.py
================================================
# -*- coding: utf-8 -*-
"""Dataset for diffusion models."""
import os
import random
import typing as tp
import numpy as np
import torch
import torch.utils.data
from torch.nn import functional as F
from deepcompressor.utils.common import tree_collate, tree_map
__all__ = ["DiffusionDataset"]
class DiffusionDataset(torch.utils.data.Dataset):
    """Dataset of pre-dumped diffusion model inputs stored as files on disk."""

    path: str
    filenames: list[str]
    filepaths: list[str]

    def __init__(self, path: str, num_samples: int = -1, seed: int = 0, ext: str = ".npy") -> None:
        """Index all sample files with extension ``ext`` under ``path``.

        When ``num_samples`` is positive and smaller than the dataset, a
        deterministic (seeded) random subset is kept, re-sorted by name.
        """
        if not os.path.exists(path):
            raise ValueError(f"Invalid data path: {path}")
        self.path = path
        # Samples may live in a "caches" subdirectory.
        if "caches" in os.listdir(path):
            path = os.path.join(path, "caches")
        names = sorted(name for name in os.listdir(path) if name.endswith(ext))
        if 0 < num_samples < len(names):
            random.Random(seed).shuffle(names)
            names = sorted(names[:num_samples])
        self.filenames = names
        self.filepaths = [os.path.join(path, name) for name in names]

    def __len__(self) -> int:
        return len(self.filepaths)

    def __getitem__(self, idx) -> dict[str, tp.Any]:
        """Load one sample dict and resolve any file-name indirections to arrays."""
        sample = np.load(self.filepaths[idx], allow_pickle=True).item()
        # A string in place of the latent / text embedding is a reference to a
        # separately stored .npy file; load it in.
        if isinstance(sample["input_args"][0], str):
            sample["input_args"][0] = np.load(os.path.join(self.path, "latents", sample["input_args"][0]))
        if isinstance(sample["input_kwargs"]["encoder_hidden_states"], str):
            sample["input_kwargs"]["encoder_hidden_states"] = np.load(
                os.path.join(self.path, "text_embs", sample["input_kwargs"]["encoder_hidden_states"])
            )
        sample = tree_map(torch.from_numpy, sample)
        # Pad encoder_hidden_states to the attention-mask length (e.g. 300 for pixart).
        kwargs = sample["input_kwargs"]
        if "encoder_attention_mask" in kwargs:
            mask = kwargs["encoder_attention_mask"]
            states = kwargs["encoder_hidden_states"]
            kwargs["encoder_hidden_states"] = F.pad(states, (0, 0, 0, mask.shape[1] - states.shape[1]))
        return sample

    def build_loader(self, **kwargs) -> torch.utils.data.DataLoader:
        """Build a DataLoader over this dataset using the tree-aware collate function."""
        return torch.utils.data.DataLoader(self, collate_fn=tree_collate, **kwargs)
================================================
FILE: deepcompressor/app/diffusion/dataset/calib.py
================================================
# -*- coding: utf-8 -*-
"""Calibration dataset for diffusion models."""
import random
import typing as tp
from collections import OrderedDict
from dataclasses import MISSING, dataclass
import torch
import torch.nn as nn
import torch.utils.data
from diffusers.models.attention import JointTransformerBlock
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import (
FluxSingleTransformerBlock,
FluxTransformerBlock,
)
from omniconfig import configclass
from deepcompressor.data.cache import (
IOTensorsCache,
ModuleForwardInput,
TensorCache,
TensorsCache,
)
from deepcompressor.data.utils.reshape import AttentionInputReshapeFn, LinearReshapeFn
from deepcompressor.dataset.action import CacheAction, ConcatCacheAction
from deepcompressor.dataset.cache import BaseCalibCacheLoader
from deepcompressor.dataset.config import BaseDataLoaderConfig
from ..nn.struct import DiffusionBlockStruct, DiffusionModelStruct
from .base import DiffusionDataset
__all__ = [
"DiffusionCalibCacheLoaderConfig",
"DiffusionCalibDataset",
"DiffusionConcatCacheAction",
"DiffusionCalibCacheLoader",
]
@configclass
@dataclass(kw_only=True)
class DiffusionCalibCacheLoaderConfig(BaseDataLoaderConfig):
    """Configuration for collecting calibration dataset for quantization.

    Args:
        data (`str`):
            Dataset name.
        num_samples (`int`):
            Number of dataset samples.
        batch_size (`int`):
            Batch size when loading dataset.
        path (`str`):
            Path to the dataset directory.
        num_workers (`int`, *optional*, defaults to `8`):
            Number of workers for data loading.
    """

    path: str
    num_workers: int = 8

    def build_dataset(self) -> "DiffusionCalibDataset":
        """Build the calibration dataset from ``self.path``, capped at ``self.num_samples``."""
        return DiffusionCalibDataset(self.path, num_samples=self.num_samples)

    def build_loader(self) -> "DiffusionCalibCacheLoader":
        """Build the calibration cache data loader backed by this configuration."""
        return DiffusionCalibCacheLoader(self)
class DiffusionCalibDataset(DiffusionDataset):
    """Calibration dataset that eagerly loads all cached ``.pt`` samples into memory."""

    data: list[dict[str, tp.Any]]

    def __init__(self, path: str, num_samples: int = -1, seed: int = 0) -> None:
        """Load every indexed ``.pt`` file and shuffle the samples deterministically."""
        super().__init__(path, num_samples=num_samples, seed=seed, ext=".pt")
        samples = [torch.load(filepath) for filepath in self.filepaths]
        random.Random(seed).shuffle(samples)
        self.data = samples

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> dict[str, tp.Any]:
        return self.data[idx]
class DiffusionConcatCacheAction(ConcatCacheAction):
    """Concatenating cache action that lazily configures attention cache layouts."""

    def info(
        self,
        name: str,
        module: nn.Module,
        tensors: dict[int | str, torch.Tensor],
        cache: TensorsCache,
    ) -> None:
        """Update cache information.

        Args:
            name (`str`):
                Module name.
            module (`nn.Module`):
                Module.
            tensors (`dict[int | str, torch.Tensor]`):
                Tensors to cache.
            cache (`TensorsCache`):
                Cache.
        """
        if isinstance(module, Attention):
            # Attention caches are created with channels_dim=None (see
            # DiffusionCalibCacheLoader._init_cache) and are configured here from
            # the first observed tensors.
            encoder_hidden_states = tensors.get("encoder_hidden_states", None)
            if encoder_hidden_states is None:
                # Self-attention: no encoder states; drop the unused cache slot.
                tensors.pop("encoder_hidden_states", None)
                cache.tensors.pop("encoder_hidden_states", None)
            else:
                encoder_hidden_states_cache = cache.tensors["encoder_hidden_states"]
                # 4-D tensors put channels at dim 1; otherwise channels are last.
                encoder_channels_dim = 1 if encoder_hidden_states.dim() == 4 else -1
                if encoder_hidden_states_cache.channels_dim is None:
                    # First batch: record the layout and pick the matching reshape fn.
                    encoder_hidden_states_cache.channels_dim = encoder_channels_dim
                    if encoder_channels_dim == -1:
                        encoder_hidden_states_cache.reshape = LinearReshapeFn()
                    else:
                        encoder_hidden_states_cache.reshape = AttentionInputReshapeFn(encoder_channels_dim)
                else:
                    # Subsequent batches must keep the same layout.
                    assert encoder_hidden_states_cache.channels_dim == encoder_channels_dim
            hidden_states, hidden_states_cache = tensors["hidden_states"], cache.tensors["hidden_states"]
            channels_dim = 1 if hidden_states.dim() == 4 else -1
            if hidden_states_cache.channels_dim is None:
                hidden_states_cache.channels_dim = channels_dim
                if channels_dim == -1:
                    hidden_states_cache.reshape = LinearReshapeFn()
                else:
                    hidden_states_cache.reshape = AttentionInputReshapeFn(channels_dim)
            else:
                assert hidden_states_cache.channels_dim == channels_dim
        return super().info(name, module, tensors, cache)
class DiffusionCalibCacheLoader(BaseCalibCacheLoader):
    """Calibration activation-cache loader for diffusion models."""

    config: DiffusionCalibCacheLoaderConfig
    dataset: DiffusionCalibDataset

    def __init__(self, config: DiffusionCalibCacheLoaderConfig) -> None:
        """Initialize the cache for the diffusion calibration dataset.

        Args:
            config (`DiffusionCalibCacheLoaderConfig`):
                Configuration for the calibration cache loader.
        """
        super().__init__(dataset=config.build_dataset(), batch_size=config.batch_size)
        # Clamp the batch size so tiny calibration sets still produce a batch
        # (iter_samples drops incomplete batches).
        self.batch_size = min(config.batch_size, len(self.dataset))
        self.config = config

    def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
        """Initialize cache.

        Args:
            name (`str`):
                Module name.
            module (`nn.Module`):
                Module.

        Returns:
            `IOTensorsCache`:
                Cache for inputs and outputs.
        """
        if isinstance(module, FluxSingleTransformerBlock):
            return IOTensorsCache(
                inputs=TensorsCache(
                    OrderedDict(
                        hidden_states=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
                        temb=TensorCache(channels_dim=1, reshape=LinearReshapeFn()),
                    )
                ),
                outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
            )
        elif isinstance(module, Attention):
            # Layout is unknown until the first forward; DiffusionConcatCacheAction
            # fills in channels_dim/reshape lazily.
            return IOTensorsCache(
                inputs=TensorsCache(
                    OrderedDict(
                        hidden_states=TensorCache(channels_dim=None, reshape=None),
                        encoder_hidden_states=TensorCache(channels_dim=None, reshape=None),
                    ),
                ),
                outputs=TensorCache(channels_dim=None, reshape=None),
            )
        else:
            return super()._init_cache(name, module)

    def iter_samples(self) -> tp.Generator[ModuleForwardInput, None, None]:
        """Yield batched model forward inputs from the calibration dataset."""
        dataloader = self.dataset.build_loader(
            batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=self.config.num_workers
        )
        for data in dataloader:
            yield ModuleForwardInput(args=data["input_args"], kwargs=data["input_kwargs"])

    def _convert_layer_inputs(
        self, m: nn.Module, args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], save_all: bool = False
    ) -> ModuleForwardInput:
        """Convert layer inputs to module forward input.

        Args:
            m (`nn.Module`):
                Layer.
            args (`tuple[Any, ...]`):
                Layer input arguments.
            kwargs (`dict[str, Any]`):
                Layer input keyword arguments.
            save_all (`bool`, *optional*, defaults to `False`):
                Whether to save all inputs.

        Returns:
            `ModuleForwardInput`:
                Module forward input.
        """
        kwargs = dict(kwargs)  # shallow copy so we can mutate safely
        if "res_hidden_states_tuple" in kwargs:
            kwargs["res_hidden_states_tuple"] = None
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
            assert len(args) == 0, f"Invalid args: {args}"
        else:
            hidden_states = args[0]
        if isinstance(m, (FluxTransformerBlock, JointTransformerBlock)):
            # These blocks take (hidden_states, encoder_hidden_states) positionally.
            if "encoder_hidden_states" in kwargs:
                encoder_hidden_states = kwargs.pop("encoder_hidden_states")
            else:
                encoder_hidden_states = args[1]
            return ModuleForwardInput(
                args=[
                    hidden_states.detach().cpu() if save_all else MISSING,
                    encoder_hidden_states.detach().cpu() if save_all else MISSING,
                ],
                kwargs=kwargs,
            )
        else:
            return ModuleForwardInput(
                args=[hidden_states.detach().cpu() if save_all else MISSING, *args[1:]], kwargs=kwargs
            )

    def _convert_layer_outputs(self, m: nn.Module, outputs: tp.Any) -> dict[str | int, tp.Any]:
        """Convert layer outputs to dictionary for updating the next layer inputs.

        Args:
            m (`nn.Module`):
                Layer.
            outputs (`Any`):
                Layer outputs.

        Returns:
            `dict[str | int, Any]`:
                Dictionary for updating the next layer inputs.
        """
        if isinstance(m, (FluxTransformerBlock, JointTransformerBlock)):
            # These blocks return (encoder_hidden_states, hidden_states), while
            # the next layer takes them in the opposite positional order.
            assert isinstance(outputs, tuple) and len(outputs) == 2
            encoder_hidden_states, hidden_states = outputs
            return {0: hidden_states.detach().cpu(), 1: encoder_hidden_states.detach().cpu()}
        else:
            return super()._convert_layer_outputs(m, outputs)

    def iter_layer_activations(  # noqa: C901
        self,
        model: nn.Module | DiffusionModelStruct,
        *args,
        needs_inputs_fn: tp.Callable[[str, nn.Module], bool],
        needs_outputs_fn: tp.Callable[[str, nn.Module], bool] | None = None,
        action: CacheAction | None = None,
        skip_pre_modules: bool = True,
        skip_post_modules: bool = True,
        **kwargs,
    ) -> tp.Generator[
        tuple[
            str,
            tuple[
                DiffusionBlockStruct | nn.Module,
                dict[str, IOTensorsCache],
                dict[str, tp.Any],
            ],
        ],
        None,
        None,
    ]:
        """Iterate over model activations in layers.

        Args:
            model (`nn.Module`):
                Model.
            action (`CacheAction`):
                Action for caching activations.
            needs_inputs_fn (`Callable[[str, nn.Module], bool]` or `bool` or `None`, *optional*, defaults to `True`):
                Function for determining whether to cache inputs for a module given its name and itself.
            needs_outputs_fn (`Callable[[str, nn.Module], bool]` or `bool` or `None`, *optional*, defaults to `None`):
                Function for determining whether to cache outputs for a module given its name and itself.
            *args: Arguments for ``iter_samples``.
            **kwargs: Keyword arguments for ``iter_samples``.

        Yields:
            Generator[
                tuple[str, tuple[DiffusionBlockStruct | nn.Module, dict[str, IOTensorsCache], dict[str, tp.Any]]],
                None,
                None
            ]:
                Generator of tuple of
                    - layer name
                    - a tuple of
                        - layer itself
                        - inputs and outputs cache of each module in the layer
                        - layer input arguments
        """
        if not isinstance(model, DiffusionModelStruct):
            model_struct = DiffusionModelStruct.construct(model)
        else:
            model_struct = model
            model = model_struct.module
        assert isinstance(model_struct, DiffusionModelStruct)
        assert isinstance(model, nn.Module)
        action = DiffusionConcatCacheAction("cpu") if action is None else action
        layers, layer_structs, recomputes, use_prev_layer_outputs = model_struct.get_iter_layer_activations_args(
            skip_pre_modules=skip_pre_modules,
            skip_post_modules=skip_post_modules,
            **self.dataset[0]["input_kwargs"],
        )
        for layer_idx, (layer_name, (layer, layer_cache, layer_inputs)) in enumerate(
            self._iter_layer_activations(
                model,
                *args,
                action=action,
                layers=layers,
                needs_inputs_fn=needs_inputs_fn,
                needs_outputs_fn=needs_outputs_fn,
                recomputes=recomputes,
                use_prev_layer_outputs=use_prev_layer_outputs,
                **kwargs,
            )
        ):
            # Expose only the auxiliary kwargs; the main tensors are in the cache.
            layer_kwargs = dict(layer_inputs[0].kwargs)
            layer_kwargs.pop("hidden_states", None)
            layer_kwargs.pop("encoder_hidden_states", None)
            layer_kwargs.pop("temb", None)
            layer_struct = layer_structs[layer_idx]
            if isinstance(layer_struct, DiffusionBlockStruct):
                assert layer_struct.name == layer_name
                assert layer is layer_struct.module
                for transformer_block_struct in layer_struct.iter_transformer_block_structs():
                    for attn_struct in transformer_block_struct.iter_attention_structs():
                        # q/k/v of a non-cross attention share the same input cache.
                        if attn_struct.q_proj_name in layer_cache:
                            if not attn_struct.is_cross_attn():
                                cache = layer_cache[attn_struct.q_proj_name]
                                layer_cache[attn_struct.k_proj_name] = cache
                                layer_cache[attn_struct.v_proj_name] = cache
                        # add_k/add_v (and add_q for joint attention) share the
                        # encoder-side input cache.
                        if attn_struct.add_k_proj_name in layer_cache:
                            assert not attn_struct.is_self_attn()
                            cache = layer_cache[attn_struct.add_k_proj_name]
                            layer_cache[attn_struct.add_v_proj_name] = cache
                            if attn_struct.is_joint_attn():
                                layer_cache[attn_struct.add_q_proj_name] = cache
                    ffn_struct = transformer_block_struct.ffn_struct
                    # BUGFIX: guard `ffn_struct is not None` BEFORE touching its
                    # config; the previous code read `ffn_struct.config.num_experts`
                    # unconditionally and raised AttributeError on blocks with no FFN.
                    if ffn_struct is not None and ffn_struct.config.num_experts > 1:
                        num_experts = ffn_struct.config.num_experts
                        # All experts of the same projection share one input cache.
                        for expert_idx in range(num_experts):
                            if ffn_struct.up_proj_names[expert_idx] in layer_cache:
                                cache = layer_cache[ffn_struct.up_proj_names[expert_idx]]
                                for up_proj_name in ffn_struct.up_proj_names[expert_idx::num_experts]:
                                    layer_cache[up_proj_name] = cache
                            if ffn_struct.down_proj_names[expert_idx] in layer_cache:
                                cache = layer_cache[ffn_struct.down_proj_names[expert_idx]]
                                for down_proj_name in ffn_struct.down_proj_names[expert_idx::num_experts]:
                                    layer_cache[down_proj_name] = cache
            yield layer_name, (layer_struct, layer_cache, layer_kwargs)
================================================
FILE: deepcompressor/app/diffusion/dataset/collect/calib.py
================================================
# -*- coding: utf-8 -*-
"""Collect calibration dataset."""
import os
from dataclasses import dataclass
import datasets
import torch
from omniconfig import configclass
from torch import nn
from tqdm import tqdm
from deepcompressor.app.diffusion.config import DiffusionPtqRunConfig
from deepcompressor.utils.common import hash_str_to_int, tree_map
from ...utils import get_control
from ..data import get_dataset
from .utils import CollectHook
def process(x: torch.Tensor) -> torch.Tensor:
    """Round-trip a tensor through a float32 numpy array, preserving its dtype.

    The float32 detour allows dtypes that numpy cannot represent directly
    (e.g. bfloat16) to be materialized before casting back.
    """
    original_dtype = x.dtype
    as_array = x.float().numpy()
    return torch.from_numpy(as_array).to(original_dtype)
def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    """Run the pipeline over ``dataset`` and dump per-step model-input caches.

    For every prompt, the generated image is saved under ``<output>/samples`` and
    one cache file per (denoising step, guidance branch) is saved under
    ``<output>/caches``.

    Args:
        config (`DiffusionPtqRunConfig`):
            Run configuration providing the pipeline and evaluation settings.
        dataset (`datasets.Dataset`):
            Dataset with at least ``filename`` and ``prompt`` columns.
    """
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []
    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    # The hook appends one cache entry per model forward (per step, per guidance branch).
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        # BUGFIX: the last batch may be smaller than `batch_size` since
        # drop_last_batch=False; all cache bookkeeping below must use the
        # actual batch length, not the configured batch size.
        cur_batch_size = len(filenames)
        # Seed each sample deterministically from its filename so reruns reproduce.
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()
        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            # NOTE(review): `collect_config` is the module-level global set in the
            # __main__ block; this function assumes it is run as a script.
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls
        result_images = pipeline(prompts, generator=generators, **pipeline_kwargs).images
        # Infer how many guidance branches ran per step from the cache count.
        num_guidances = (len(caches) // cur_batch_size) // config.eval.num_steps
        num_steps = len(caches) // (cur_batch_size * num_guidances)
        assert (
            len(caches) == cur_batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {cur_batch_size} * {num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            # BUGFIX: name outputs after the sample; a constant name would make
            # every sample overwrite the previous one.
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    # Caches are appended step-major, then guidance, then sample.
                    c = caches[s * cur_batch_size * num_guidances + g * cur_batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    # Deep-copy tensors through numpy to detach from hook storage.
                    c = tree_map(process, c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()
@configclass
@dataclass
class CollectConfig:
    """Configuration for collecting calibration dataset.

    Args:
        root (`str`, *optional*, defaults to `"datasets"`):
            Root directory to save the collected dataset.
        dataset_name (`str`, *optional*, defaults to `"qdiff"`):
            Name of the collected dataset.
        data_path (`str`, *optional*, defaults to `"prompts/qdiff.yaml"`):
            Path to the prompt file.
        num_samples (`int`, *optional*, defaults to `128`):
            Number of samples to collect.
    """

    root: str = "datasets"
    dataset_name: str = "qdiff"
    data_path: str = "prompts/qdiff.yaml"
    num_samples: int = 128
if __name__ == "__main__":
    # Parse the PTQ run config plus the collection-specific options under "collect".
    parser = DiffusionPtqRunConfig.get_parser()
    parser.add_config(CollectConfig, scope="collect", prefix="collect")
    configs, _, unused_cfgs, unused_args, unknown_args = parser.parse_known_args()
    ptq_config, collect_config = configs[""], configs["collect"]
    assert isinstance(ptq_config, DiffusionPtqRunConfig)
    assert isinstance(collect_config, CollectConfig)
    if len(unused_cfgs) > 0:
        print(f"Warning: unused configurations {unused_cfgs}")
    if unused_args is not None:
        print(f"Warning: unused arguments {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"
    # Caches land under: <root>/<dtype>/<model>/<protocol>/<dataset>/s<num_samples>.
    collect_dirpath = os.path.join(
        collect_config.root,
        str(ptq_config.pipeline.dtype),
        ptq_config.pipeline.name,
        ptq_config.eval.protocol,
        collect_config.dataset_name,
        f"s{collect_config.num_samples}",
    )
    print(f"Saving caches to {collect_dirpath}")
    dataset = get_dataset(
        collect_config.data_path,
        max_dataset_size=collect_config.num_samples,
        # Ground-truth images are only needed for control-image tasks.
        return_gt=ptq_config.pipeline.task in ["canny-to-image"],
        repeat=1,
    )
    # Redirect all outputs of this run into the collection directory.
    ptq_config.output.root = collect_dirpath
    os.makedirs(ptq_config.output.root, exist_ok=True)
    collect(ptq_config, dataset=dataset)
================================================
FILE: deepcompressor/app/diffusion/dataset/collect/utils.py
================================================
# -*- coding: utf-8 -*-
"""Common utilities for collecting data."""
import inspect
import typing as tp
import torch
import torch.nn as nn
from diffusers.models.transformers import (
FluxTransformer2DModel,
PixArtTransformer2DModel,
SanaTransformer2DModel,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from deepcompressor.utils.common import tree_map, tree_split
__all__ = ["CollectHook"]
class CollectHook:
def __init__(self, caches: list[dict[str, tp.Any]] = None, zero_redundancy: bool = False) -> None:
self.caches = [] if caches is None else caches
self.zero_redundancy = zero_redundancy
def __call__(
self,
module: nn.Module,
input_args: tuple[torch.Tensor, ...],
input_kwargs: dict[str, tp.Any],
output: tuple[torch.Tensor, ...],
) -> tp.Any:
new_args = []
signature = inspect.signature(module.forward)
bound_arguments = signature.bind(*input_args, **input_kwargs)
arguments = bound_arguments.arguments
args_to_kwargs = {k: v for k, v in arguments.items() if k not in input_kwargs}
input_kwargs.update(args_to_kwargs)
if isinstance(module, UNet2DConditionModel):
sample = input_kwargs.pop("sample")
new_args.append(sample)
timestep = input_kwargs["timestep"]
timesteps = timestep
if not torch.is_tensor(timesteps):
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
input_kwargs["timestep"] = timesteps
elif isinstance(module, (PixArtTransformer2DModel, SanaTransformer2DModel)):
new_args.append(input_kwargs.pop("hidden_states"))
elif isinstance(module, FluxTransformer2DModel):
new_args.append(input_kwargs.pop("hidden_states"))
else:
raise ValueError(f"Unknown model: {module}")
cache = tree_map(lambda x: x.cpu(), {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": output})
split_cache = tree_split(cache)
if isinstance(module, PixArtTransformer2DModel) and self.zero_redundancy:
for cache in split_cache:
cache_kwargs = cache["input_kwargs"]
encoder_hidden_states = cache_kwargs.pop("encoder_hidden_states")
assert encoder_hidden_states.shape[0] == 1
encoder_attention_mask = cache_kwargs.get("encoder_attention_mask", None)
if encoder_attention_mask is not None:
encoder_hidden_states = encoder_hidden_states[:, : max(encoder_attention_mask.sum(), 1)]
cache_kwargs["encoder_hidden_states"] = encoder_hidden_states
self.caches.extend(split_cache)
================================================
FILE: deepcompressor/app/diffusion/dataset/data/COCO/COCO.py
================================================
# coding=utf-8
# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO"""
import json
import os
import random
from pathlib import Path
import datasets
from PIL import Image
_CITATION = """
@article{DBLP:journals/corr/LinMBHPRDZ14,
author = {Tsung{-}Yi Lin and
Michael Maire and
Serge J. Belongie and
Lubomir D. Bourdev and
Ross B. Girshick and
James Hays and
Pietro Perona and
Deva Ramanan and
Piotr Doll{\'{a}}r and
C. Lawrence Zitnick},
title = {Microsoft {COCO:} Common Objects in Context},
journal = {CoRR},
volume = {abs/1405.0312},
year = {2014},
url = {http://arxiv.org/abs/1405.0312},
eprinttype = {arXiv},
eprint = {1405.0312},
timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
biburl = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""
_DESCRIPTION = """
MS COCO is a large-scale object detection, segmentation, and captioning dataset.
COCO has several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
5 captions per image, 250,000 people with keypoints.
"""
_HOMEPAGE = "https://cocodataset.org/#home"
_LICENSE = "CC BY 4.0"
# Download sources for the raw COCO images (2014 release).
_IMAGES_URLS = {
    "train": "http://images.cocodataset.org/zips/train2014.zip",
    "validation": "http://images.cocodataset.org/zips/val2014.zip",
}
# Karpathy caption splits (train/restval/val/test) bundled as a single JSON archive.
_KARPATHY_FILES_URL = "https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip"
# Schema of every generated example.
# NOTE(review): "filepath" is declared here but does not appear to be yielded by
# the builder's `_generate_examples` — confirm against the generator below.
_FEATURES = datasets.Features(
    {
        "filepath": datasets.Value("string"),
        "filename": datasets.Value("string"),
        "image": datasets.Image(),
        "image_path": datasets.Value("string"),
        "image_root": datasets.Value("string"),
        "prompt": datasets.Value("string"),
        "prompt_id": datasets.Value("int32"),
        "imgid": datasets.Value("int32"),
        "split": datasets.Value("string"),
        "cocoid": datasets.Value("int32"),
        "sentences_raw": [datasets.Value("string")],
        "sentids": [datasets.Value("int32")],
        "sentences_sentid": [datasets.Value("int32")],
        "sentences_tokens": [[datasets.Value("string")]],
    }
)
def hash_string_to_int(s: str) -> int:
    """Deterministic polynomial rolling hash of ``s`` modulo a large prime.

    Equivalent to the Java-style ``hash = hash * 31 + ord(c)`` recurrence reduced
    mod 1e9+7, so the result is stable across Python processes (unlike the
    built-in, salted ``hash``).
    """
    prime = 10**9 + 7  # large prime keeps the accumulator bounded
    acc = 0
    for ch in s:
        acc = (acc * 31 + ord(ch)) % prime
    return acc
class COCOConfig(datasets.BuilderConfig):
    """Builder configuration for the COCO prompt dataset.

    Adds two options on top of the stock ``BuilderConfig``:
    ``max_dataset_size`` (cap on the number of examples; ``-1`` keeps all) and
    ``return_gt`` (whether examples carry the decoded ground-truth image).
    """

    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        base_kwargs = {
            "name": kwargs.get("name", "default"),
            "version": kwargs.get("version", "0.0.0"),
            "data_dir": kwargs.get("data_dir", None),
            "data_files": kwargs.get("data_files", None),
            "description": kwargs.get("description", None),
        }
        super().__init__(**base_kwargs)
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt
class COCO(datasets.GeneratorBasedBuilder):
    """COCO caption dataset builder (Karpathy splits) used as a prompt source."""

    VERSION = datasets.Version("0.0.0")
    BUILDER_CONFIG_CLASS = COCOConfig
    BUILDER_CONFIGS = [
        COCOConfig(name="COCO_val", version=VERSION, description="COCO validation prompt set"),
        COCOConfig(name="COCO_train", version=VERSION, description="COCO train prompt set"),
        COCOConfig(name="COCO_full", version=VERSION, description="COCO full prompt set"),
    ]
    DEFAULT_CONFIG_NAME = "COCO_val"

    def _info(self):
        """Return dataset metadata with the module-level ``_FEATURES`` schema."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        """Download the Karpathy annotations and the COCO-2014 image archives."""
        annotation_file = os.path.join(dl_manager.download_and_extract(_KARPATHY_FILES_URL), "dataset_coco.json")
        image_folders = {k: Path(v) for k, v in dl_manager.download_and_extract(_IMAGES_URLS).items()}
        # "COCO_full" merges both Karpathy partitions; otherwise pick by config suffix.
        if self.config.name == "COCO_full":
            split_keys = ["validation", "train"]
        else:
            split_keys = [self.config.name.split("_")[-1]]
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "annotation_file": annotation_file,
                    "image_folders": image_folders,
                    "split_keys": split_keys,
                },
            ),
        ]

    def _generate_examples(
        self, annotation_file: str, image_folders: dict[str, str], split_keys: list[str] | tuple[str, ...]
    ):
        """Yield (index, example) pairs for the requested Karpathy split(s)."""
        with open(annotation_file, "r", encoding="utf-8") as fi:
            annotations = json.load(fi)
        # Collect image metadata belonging to the requested split(s);
        # "train" also absorbs the Karpathy "restval" partition.
        metas = []
        for split_key in split_keys:
            for image_metadata in annotations["images"]:
                if split_key == "train":
                    if image_metadata["split"] != "train" and image_metadata["split"] != "restval":
                        continue
                elif split_key == "val":
                    if image_metadata["split"] != "val":
                        continue
                elif split_key == "test":
                    if image_metadata["split"] != "test":
                        continue
                metas.append(image_metadata)
        if self.config.max_dataset_size > 0:
            # Deterministic subsample: fixed-seed shuffle, then re-sort for stable ids.
            random.Random(0).shuffle(metas)
            metas = metas[: self.config.max_dataset_size]
            metas = sorted(metas, key=lambda x: x["filename"])
        for i, meta in enumerate(metas):
            if "val2014" in meta["filename"]:
                image_root = os.path.join(image_folders["validation"], "val2014")
            else:
                image_root = os.path.join(image_folders["train"], "train2014")
            filename = meta["filename"].replace(".jpg", "").replace(".png", "")
            image_path = os.path.join(image_root, filename + ".jpg")
            sentences_raw = [caption["raw"] for caption in meta["sentences"]]
            # Pick one of the captions deterministically per image.
            prompt_id = hash_string_to_int(filename) % len(sentences_raw)
            prompt = sentences_raw[prompt_id]
            yield (
                i,
                {
                    "filename": filename,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "image_path": image_path,
                    "image_root": image_root,
                    "prompt": prompt,
                    "prompt_id": prompt_id,
                    "imgid": meta["imgid"],
                    "split": self.config.name,
                    # fixed: key was "coco_id", which does not exist in _FEATURES ("cocoid")
                    "cocoid": meta["cocoid"],
                    "sentences_raw": sentences_raw,
                    "sentids": meta["sentids"],
                    "sentences_sentid": [caption["sentid"] for caption in meta["sentences"]],
                    "sentences_tokens": [caption["tokens"] for caption in meta["sentences"]],
                },
            )
================================================
FILE: deepcompressor/app/diffusion/dataset/data/COCO/__init__.py
================================================
================================================
FILE: deepcompressor/app/diffusion/dataset/data/DCI/DCI.py
================================================
import os
import random
import datasets
import yaml
from PIL import Image
_CITATION = """\
@InProceedings{Urbanek_2024_CVPR,
author = {Urbanek, Jack and Bordes, Florian and Astolfi, Pietro and Williamson, Mary and Sharma, Vasu and Romero-Soriano, Adriana},
title = {A Picture is Worth More Than 77 Text Tokens: Evaluating CLIP-Style Models on Dense Captions},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2024},
pages = {26700-26709}
}
""" # noqa: E501
_DESCRIPTION = """\
The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B,
each with a complete description aiming to capture the full visual detail of what is present in the image.
Much of the description is directly aligned to submasks of the image.
"""
_HOMEPAGE = "https://github.com/facebookresearch/DCI"
_LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"
IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"
PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}
class DCIConfig(datasets.BuilderConfig):
    """Builder configuration for the sDCI prompt dataset.

    Extends ``BuilderConfig`` with ``max_dataset_size`` (cap on examples, ``-1``
    keeps all) and ``return_gt`` (whether to decode the ground-truth image).
    """

    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        base_kwargs = {
            "name": kwargs.get("name", "default"),
            "version": kwargs.get("version", "0.0.0"),
            "data_dir": kwargs.get("data_dir", None),
            "data_files": kwargs.get("data_files", None),
            "description": kwargs.get("description", None),
        }
        super().__init__(**base_kwargs)
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt
class DCI(datasets.GeneratorBasedBuilder):
    """Densely Captioned Images (sDCI) prompt dataset builder."""

    VERSION = datasets.Version("0.0.0")
    BUILDER_CONFIG_CLASS = DCIConfig
    BUILDER_CONFIGS = [DCIConfig(name="sDCI", version=VERSION, description="sDCI full prompt set")]
    DEFAULT_CONFIG_NAME = "sDCI"

    def _info(self):
        """Return dataset metadata and the per-example feature schema."""
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                "meta_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        """Download the prompt YAML and the packed image archive."""
        image_url = IMAGE_URL
        meta_url = PROMPT_URLS[self.config.name]
        meta_path = dl_manager.download(meta_url)
        image_root = dl_manager.download_and_extract(image_url)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
            )
        ]

    def _generate_examples(self, meta_path: str, image_root: str):
        """Yield (index, example) pairs from the {name: prompt} YAML mapping."""
        # Use a context manager so the YAML file handle is closed promptly
        # (previously the handle from open() was never closed).
        with open(meta_path, "r") as f:
            meta = yaml.safe_load(f)
        names = list(meta.keys())
        if self.config.max_dataset_size > 0:
            # Deterministic subsample: fixed-seed shuffle, then re-sort for stable ids.
            random.Random(0).shuffle(names)
            names = names[: self.config.max_dataset_size]
            names = sorted(names)
        for i, name in enumerate(names):
            prompt = meta[name]
            image_path = os.path.join(image_root, f"{name}.jpg")
            yield (
                i,
                {
                    "filename": name,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    "meta_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                },
            )
================================================
FILE: deepcompressor/app/diffusion/dataset/data/DCI/__init__.py
================================================
================================================
FILE: deepcompressor/app/diffusion/dataset/data/MJHQ/MJHQ.py
================================================
import json
import os
import random
import datasets
from PIL import Image
_CITATION = """\
@misc{li2024playground,
title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
year={2024},
eprint={2402.17245},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
"""
_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""
_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"
_LICENSE = (
"Playground v2.5 Community License "
"(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)
IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"
META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"
class MJHQConfig(datasets.BuilderConfig):
    """Builder configuration for the MJHQ-30K dataset.

    Extends ``BuilderConfig`` with ``max_dataset_size`` (cap on examples, ``-1``
    keeps all) and ``return_gt`` (whether to decode the ground-truth image).
    """

    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        base_kwargs = {
            "name": kwargs.get("name", "default"),
            "version": kwargs.get("version", "0.0.0"),
            "data_dir": kwargs.get("data_dir", None),
            "data_files": kwargs.get("data_files", None),
            "description": kwargs.get("description", None),
        }
        super().__init__(**base_kwargs)
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt
# NOTE(review): this builder is named `DCI` but lives in MJHQ.py and builds
# MJHQ-30K — almost certainly a copy-paste misnomer. The name is kept to avoid
# changing the public interface; confirm whether it should be renamed to `MJHQ`.
class DCI(datasets.GeneratorBasedBuilder):
    """MJHQ-30K prompt/image dataset builder."""

    VERSION = datasets.Version("0.0.0")
    BUILDER_CONFIG_CLASS = MJHQConfig
    BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
    DEFAULT_CONFIG_NAME = "MJHQ"

    def _info(self):
        """Return dataset metadata and the per-example feature schema."""
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "category": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                "prompt_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        """Download the metadata JSON and the packed image archive."""
        meta_path = dl_manager.download(META_URL)
        image_root = dl_manager.download_and_extract(IMAGE_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
            ),
        ]

    def _generate_examples(self, meta_path: str, image_root: str):
        """Yield (index, example) pairs from the MJHQ metadata mapping."""
        with open(meta_path, "r") as f:
            meta = json.load(f)
        names = list(meta.keys())
        if self.config.max_dataset_size > 0:
            # Deterministic subsample: fixed-seed shuffle, then re-sort for stable ids.
            random.Random(0).shuffle(names)
            names = names[: self.config.max_dataset_size]
            names = sorted(names)
        for i, name in enumerate(names):
            category = meta[name]["category"]
            prompt = meta[name]["prompt"]
            image_path = os.path.join(image_root, category, f"{name}.jpg")
            yield (
                i,
                {
                    "filename": name,
                    "category": category,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    # fixed: key was "meta_path", which does not exist in the
                    # declared features ("prompt_path")
                    "prompt_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                },
            )
================================================
FILE: deepcompressor/app/diffusion/dataset/data/MJHQ/__init__.py
================================================
================================================
FILE: deepcompressor/app/diffusion/dataset/data/__init__.py
================================================
import os
import random
import datasets
import yaml
__all__ = ["get_dataset"]
def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
    """Load a ``{name: prompt}`` YAML file into a columnar dict of examples.

    Args:
        meta_path (`str`): Path to the YAML file mapping names to prompts.
        max_dataset_size (`int`, *optional*, defaults to `-1`):
            Cap on the number of prompts; non-positive keeps all.
        repeat (`int`, *optional*, defaults to `4`):
            Number of copies per prompt; filenames are suffixed ``-0`` … ``-{repeat-1}``.

    Returns:
        `dict`: Columns ``filename``, ``prompt``, and ``meta_path``, each of
            length ``len(names) * repeat``.
    """
    # Close the YAML handle promptly (previously the open() result was never closed).
    with open(meta_path, "r") as f:
        meta = yaml.safe_load(f)
    names = list(meta.keys())
    if max_dataset_size > 0:
        # Deterministic subsample: fixed-seed shuffle, then re-sort for stable order.
        random.Random(0).shuffle(names)
        names = names[:max_dataset_size]
        names = sorted(names)
    ret = {"filename": [], "prompt": [], "meta_path": []}
    # (removed an `idx` counter that was incremented but never read)
    for name in names:
        prompt = meta[name]
        for j in range(repeat):
            ret["filename"].append(f"{name}-{j}")
            ret["prompt"].append(prompt)
            ret["meta_path"].append(meta_path)
    return ret
def get_dataset(
    name: str,
    config_name: str | None = None,
    split: str = "train",
    max_dataset_size: int = -1,
    return_gt: bool = False,
    repeat: int = 4,
    chunk_start: int = 0,
    chunk_step: int = 1,
) -> datasets.Dataset:
    """Load a benchmark dataset by name, or build one from a prompt YAML file.

    Args:
        name (`str`): Either a builder name (``COCO``/``DCI``/``MJHQ``) or a
            path ending in ``.yaml``/``.yml`` to a ``{name: prompt}`` file.
        config_name (`str` or `None`, *optional*): Builder config name.
        split (`str`, *optional*, defaults to `"train"`): Dataset split.
        max_dataset_size (`int`, *optional*, defaults to `-1`): Cap on examples.
        return_gt (`bool`, *optional*, defaults to `False`): Whether examples
            carry ground-truth images (builder datasets only).
        repeat (`int`, *optional*, defaults to `4`): Copies per prompt (YAML only).
        chunk_start (`int`, *optional*, defaults to `0`): Start index of the chunk.
        chunk_step (`int`, *optional*, defaults to `1`): Stride between chunk items.

    Returns:
        `datasets.Dataset`: The (possibly chunked) dataset, annotated with
            ``_unchunk_size``, ``_chunk_start``, and ``_chunk_step`` attributes.
    """
    prefix = os.path.dirname(__file__)
    kwargs = {
        "name": config_name,
        "split": split,
        "trust_remote_code": True,
        "token": True,
        "max_dataset_size": max_dataset_size,
    }
    if name.endswith((".yaml", ".yml")):
        dataset = datasets.Dataset.from_dict(
            load_dataset_yaml(name, max_dataset_size=max_dataset_size, repeat=repeat),
            features=datasets.Features(
                {
                    "filename": datasets.Value("string"),
                    "prompt": datasets.Value("string"),
                    "meta_path": datasets.Value("string"),
                }
            ),
        )
    elif name in ("COCO", "DCI", "MJHQ"):
        # The three builders share an identical loading call (previously three
        # duplicated branches), differing only in the script directory.
        dataset = datasets.load_dataset(os.path.join(prefix, name), return_gt=return_gt, **kwargs)
    else:
        raise ValueError(f"Unknown dataset name: {name}")
    # Annotate the dataset with chunking bookkeeping used by downstream tooling.
    assert not hasattr(dataset, "_unchunk_size")
    assert not hasattr(dataset, "_chunk_start")
    assert not hasattr(dataset, "_chunk_step")
    unchunk_size = len(dataset)
    if chunk_step > 1 or chunk_start > 0:
        assert 0 <= chunk_start < chunk_step
        dataset = dataset.select(range(chunk_start, len(dataset), chunk_step))
    else:
        chunk_start, chunk_step = 0, 1
    dataset._unchunk_size = unchunk_size
    dataset._chunk_start = chunk_start
    dataset._chunk_step = chunk_step
    return dataset
================================================
FILE: deepcompressor/app/diffusion/dataset/data/dump.py
================================================
import argparse
import os
import yaml
from tqdm import tqdm
from ...utils import get_control
from . import get_dataset
if __name__ == "__main__":
    # Dump benchmark prompts (and optionally images + control inputs) to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmarks", type=str, nargs="*", default=["COCO", "DCI", "MJHQ"])
    parser.add_argument("--max-dataset-size", type=int, default=-1)
    parser.add_argument("--dump-root", type=str, default="benchmarks")
    parser.add_argument("--copy-images", action="store_true")
    parser.add_argument("--prompts-only", action="store_true")
    parser.add_argument("--controls", type=str, nargs="*", default=["canny-to-image", "depth-to-image", "inpainting"])
    parser.add_argument("--chunk-start", type=int, default=0)
    parser.add_argument("--chunk-step", type=int, default=1)
    args = parser.parse_args()
    if "depth-to-image" in args.controls:
        # Lazy import: the depth preprocessor is only needed for depth controls.
        from image_gen_aux import DepthPreprocessor

        processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf").to("cuda")
    for benchmark in args.benchmarks:
        dataset = get_dataset(
            benchmark,
            max_dataset_size=args.max_dataset_size,
            return_gt=True,
            chunk_start=args.chunk_start,
            chunk_step=args.chunk_step,
        )
        prompts = {}
        # Output layout: <dump_root>/<benchmark>/<config_name>-<unchunked size>/
        benchmark_root = os.path.join(args.dump_root, benchmark, f"{dataset.config_name}-{dataset._unchunk_size}")
        for row in tqdm(dataset, desc=f"Dumping {dataset.config_name}"):
            prompts[row["filename"]] = row["prompt"]
            if not args.prompts_only:
                image = row.get("image", None)
                if image is not None:
                    image_root = os.path.join(benchmark_root, "images")
                    os.makedirs(image_root, exist_ok=True)
                    if args.copy_images:
                        # Re-encode as PNG under the benchmark directory.
                        image.save(os.path.join(image_root, row["filename"] + ".png"))
                    else:
                        # Symlink the original file instead of copying; keep its extension.
                        ext = os.path.basename(row["image_path"]).split(".")[-1]
                        os.symlink(
                            os.path.abspath(os.path.expanduser(row["image_path"])),
                            os.path.abspath(os.path.expanduser(os.path.join(image_root, row["filename"] + f".{ext}"))),
                        )
                    if "canny-to-image" in args.controls:
                        canny_root = os.path.join(benchmark_root, "canny_images")
                        os.makedirs(canny_root, exist_ok=True)
                        canny = get_control("canny-to-image", image)
                        canny.save(os.path.join(canny_root, row["filename"] + ".png"))
                    if "depth-to-image" in args.controls:
                        depth_root = os.path.join(benchmark_root, "depth_images")
                        os.makedirs(depth_root, exist_ok=True)
                        depth = get_control("depth-to-image", image, processor=processor)
                        depth.save(os.path.join(depth_root, row["filename"] + ".png"))
                    if "inpainting" in args.controls:
                        mask_root = os.path.join(benchmark_root, "mask_images")
                        cropped_image_root = os.path.join(benchmark_root, "cropped_images")
                        os.makedirs(mask_root, exist_ok=True)
                        os.makedirs(cropped_image_root, exist_ok=True)
                        cropped_image, mask_image = get_control("inpainting", image, names=row["filename"])
                        cropped_image.save(os.path.join(cropped_image_root, row["filename"] + ".png"))
                        mask_image.save(os.path.join(mask_root, row["filename"] + ".png"))
        # Only write prompts.yaml for unchunked runs (a chunk only sees a subset).
        if args.chunk_step == 1:
            os.makedirs(benchmark_root, exist_ok=True)
            with open(os.path.join(benchmark_root, "prompts.yaml"), "w") as f:
                yaml.dump(prompts, f)
================================================
FILE: deepcompressor/app/diffusion/eval/__init__.py
================================================
# -*- coding: utf-8 -*-
from .config import DiffusionEvalConfig
================================================
FILE: deepcompressor/app/diffusion/eval/config.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model evaluation."""
import logging
import os
import typing as tp
from dataclasses import dataclass, field
import datasets
import diffusers
import omniconfig
import torch
from diffusers import DiffusionPipeline
from omniconfig import configclass
from torch import multiprocessing as mp
from tqdm import tqdm
from deepcompressor.app.diffusion.dataset.data import get_dataset
from deepcompressor.utils.common import hash_str_to_int
from ..utils import get_control
from .metrics import compute_image_metrics
__all__ = ["DiffusionEvalConfig"]
@configclass
@dataclass
class DiffusionEvalConfig:
    """Diffusion model evaluation configuration.

    Args:
        protocol (`str`):
            The protocol of the evaluation pipeline.
        num_gpus (`int`, *optional*, defaults to `1`):
            The number of GPUs to use.
        batch_size (`int`, *optional*, defaults to `1`):
            The batch size used for inference.
        batch_size_per_gpu (`int`, *optional*, defaults to `None`):
            The batch size per GPU.
        height (`int`, *optional*, defaults to `None`):
            The height of the generated images.
        width (`int`, *optional*, defaults to `None`):
            The width of the generated images.
        clean_caption (`bool`, *optional*, defaults to `None`):
            Whether to clean the caption.
        num_steps (`int`, *optional*, defaults to `None`):
            The number of inference steps.
        guidance_scale (`float`, *optional*, defaults to `None`):
            The guidance scale.
        num_samples (`int`, *optional*, defaults to `1024`):
            The number of samples to generate.
        benchmarks (`list[str]`, *optional*, defaults to `["COCO", "DCI", "MJHQ", "GenEval"]`):
            The benchmark datasets to evaluate on.
        gt_metrics (`list[str]`, *optional*, defaults to `["clip_iqa", "clip_score", "psnr", "lpips", "ssim", "fid"]`):
            The ground truth metrics to compute.
        ref_metrics (`list[str]`, *optional*, defaults to `["psnr", "lpips", "ssim", "fid"]`):
            The reference metrics to compute.
        ref_root (`str`, *optional*, defaults to `""`):
            The root directory path to the reference images.
        gt_stats_root (`str`, *optional*, defaults to `""`):
            The root directory path to the ground truth statistics.
        chunk_start (`int`, *optional*, defaults to `0`):
            The starting chunk index.
        chunk_step (`int`, *optional*, defaults to `1`):
            The chunk step size.
    """

    protocol: str
    num_gpus: int = field(default=1, metadata={omniconfig.ARGPARSE_ARGS: ("--num-gpus", "-n")})
    batch_size: int = 1
    batch_size_per_gpu: int | None = None
    height: int | None = None
    width: int | None = None
    clean_caption: bool | None = None
    num_steps: int | None = None
    guidance_scale: float | None = None
    num_samples: int = 1024
    benchmarks: list[str] = field(
        default_factory=lambda: ["COCO", "DCI", "MJHQ", "GenEval"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    gt_metrics: list[str] = field(
        default_factory=lambda: ["clip_iqa", "clip_score", "image_reward", "fid"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    ref_metrics: list[str] = field(
        default_factory=lambda: ["psnr", "lpips", "ssim"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    gen_root: str = ""
    ref_root: str = ""
    gt_stats_root: str = ""
    control_root: str | None = None
    chunk_start: int = 0
    chunk_step: int = 1
    chunk_only: bool = False

    def __post_init__(self):
        assert self.protocol
        # Fill protocol template placeholders (e.g. "{num_steps}") from the config.
        self.protocol = self.protocol.lower().format(num_steps=self.num_steps, guidance_scale=self.guidance_scale)
        assert 0 <= self.chunk_start < self.chunk_step
        if self.chunk_start == 0 and self.chunk_step == 1:
            # A full run is never "chunk only".
            self.chunk_only = False

    def get_pipeline_kwargs(self) -> dict[str, tp.Any]:
        """Collect the non-`None` generation options as pipeline call kwargs."""
        kwargs = {}
        if self.height is not None:
            kwargs["height"] = self.height
        if self.width is not None:
            kwargs["width"] = self.width
        if self.clean_caption is not None:
            kwargs["clean_caption"] = self.clean_caption
        if self.num_steps is not None:
            kwargs["num_inference_steps"] = self.num_steps
        if self.guidance_scale is not None:
            kwargs["guidance_scale"] = self.guidance_scale
        return kwargs

    def _generate(
        self,
        rank: int,
        dataset: datasets.Dataset,
        pipeline: DiffusionPipeline,
        dirpath: str,
        logger: logging.Logger,
        dataset_name: str | None = None,
        task: str = "text-to-image",
        control_root: str | None = None,
    ) -> None:
        """Generate images for this rank's share of the dataset into `dirpath`."""
        if self.num_gpus > 1:
            pipeline = pipeline.to(rank)
        if rank == 0:
            logger.info(
                f"  {dataset.config_name} has {len(dataset)} samples "
                f"(chunk_start={dataset._chunk_start}, chunk_step={dataset._chunk_step},"
                f" unchunk_size={dataset._unchunk_size})"
            )
        pipeline.set_progress_bar_config(
            desc="Sampling",
            leave=False,
            dynamic_ncols=True,
            position=1,
            disable=self.num_gpus > 1,
        )
        if dataset_name is None:
            dataset_name = dataset.config_name
        for batch in tqdm(
            dataset.iter(batch_size=self.batch_size, drop_last_batch=False),
            desc=dataset_name if self.num_gpus == 1 else f"{dataset_name} (GPU {rank})",
            leave=False,
            dynamic_ncols=True,
            position=rank,
            total=(len(dataset) + self.batch_size - 1) // self.batch_size,
        ):
            # Each rank takes a strided slice of the batch.
            filenames = batch["filename"][rank :: self.num_gpus]
            if len(filenames) == 0:
                continue
            # Skip batches whose outputs all exist already (resume support).
            # fixed: path used a constant "(unknown).png" instead of the filename
            if all(os.path.exists(os.path.join(dirpath, f"{filename}.png")) for filename in filenames):
                continue
            prompts = batch["prompt"][rank :: self.num_gpus]
            # Per-image seeds derived from filenames keep generation deterministic.
            seeds = [hash_str_to_int(name) for name in filenames]
            diffusers.training_utils.set_seed(seeds[0])
            generators = [torch.Generator().manual_seed(seed) for seed in seeds]
            pipeline_kwargs = self.get_pipeline_kwargs()
            if task in ["canny-to-image", "depth-to-image", "inpainting"]:
                controls = get_control(
                    task,
                    batch["image"],
                    names=batch["filename"],
                    data_root=os.path.join(control_root, f"{dataset_name}-{dataset._unchunk_size}"),
                )
                if task == "inpainting":
                    pipeline_kwargs["image"] = controls[0]
                    pipeline_kwargs["mask_image"] = controls[1]
                else:
                    pipeline_kwargs["control_image"] = controls
            output = pipeline(prompts, generator=generators, **pipeline_kwargs)
            images = output.images
            for filename, image in zip(filenames, images, strict=True):
                # fixed: every image was saved to the same constant "(unknown).png"
                image.save(os.path.join(dirpath, f"{filename}.png"))

    def generate(
        self,
        pipeline: DiffusionPipeline,
        gen_root: str = "",
        task: str = "text-to-image",
    ) -> None:
        """Generate samples for every configured benchmark under `gen_root`."""
        logger = logging.getLogger(f"{__name__}.DiffusionEval")
        gen_root = gen_root or self.gen_root
        for benchmark in self.benchmarks:
            dataset = get_dataset(
                benchmark,
                max_dataset_size=self.num_samples,
                chunk_start=self.chunk_start,
                chunk_step=self.chunk_step,
                return_gt=task in ["canny-to-image"],
                repeat=1,
            )
            if benchmark.endswith(".yaml") or benchmark.endswith(".yml"):
                dataset_name = os.path.splitext(os.path.basename(benchmark))[0]
                dirpath = os.path.join(
                    gen_root,
                    "samples",
                    "YAML",
                    f"{dataset_name}-{dataset._unchunk_size}",
                )
            else:
                dataset_name = dataset.config_name
                dirpath = os.path.join(
                    gen_root,
                    "samples",
                    benchmark,
                    f"{dataset.config_name}-{dataset._unchunk_size}",
                )
            if self.chunk_only:
                dirpath += f".{dataset._chunk_start}.{dataset._chunk_step}"
            os.makedirs(dirpath, exist_ok=True)
            # fixed: `control_root` defaults to None, and os.path.join(None, ...)
            # raises; the control path is only needed for control tasks anyway.
            control_dirpath = os.path.join(self.control_root, benchmark) if self.control_root else None
            args = (dataset, pipeline, dirpath, logger, dataset_name, task, control_dirpath)
            if self.num_gpus == 1:
                self._generate(0, *args)
            else:
                mp.spawn(self._generate, args=args, nprocs=self.num_gpus, join=True)

    def evaluate(
        self, pipeline: DiffusionPipeline, gen_root: str = "", skip_gen: bool = False, task: str = "text-to-image"
    ) -> dict[str, tp.Any] | None:
        """Optionally generate samples, then compute image metrics (full runs only)."""
        gen_root = gen_root or self.gen_root
        if not skip_gen:
            self.generate(pipeline, gen_root=gen_root, task=task)
        if not self.chunk_only:
            return compute_image_metrics(
                gen_root=gen_root,
                benchmarks=self.benchmarks,
                max_dataset_size=self.num_samples,
                chunk_start=self.chunk_start,
                chunk_step=self.chunk_step,
                ref_root=self.ref_root,
                gt_stats_root=self.gt_stats_root,
                gt_metrics=self.gt_metrics,
                ref_metrics=self.ref_metrics,
            )
        else:
            return {}
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/__init__.py
================================================
import logging
import os
from deepcompressor.app.diffusion.dataset.data import get_dataset
from .fid import compute_fid
from .image_reward import compute_image_reward
from .multimodal import compute_image_multimodal_metrics
from .similarity import compute_image_similarity_metrics
logging.getLogger("PIL").setLevel(logging.WARNING)
__all__ = ["compute_image_metrics"]
def compute_image_metrics(
    gen_root: str,
    benchmarks: str | tuple[str, ...] = ("DCI", "GenAIBench", "GenEval", "MJHQ", "T2ICompBench"),
    max_dataset_size: int = -1,
    chunk_start: int = 0,
    chunk_step: int = 1,
    chunk_only: bool = False,
    ref_root: str = "",
    gt_stats_root: str = "",
    gt_metrics: tuple[str, ...] = ("clip_iqa", "clip_score", "image_reward", "fid"),
    ref_metrics: tuple[str, ...] = ("psnr", "lpips", "ssim", "fid"),
) -> dict:
    """Compute image metrics for generated samples against ground truth and/or reference images.

    Returns a dict keyed by "<config_name>-<size>" with optional "with_gt"
    (metrics vs. ground truth) and "with_orig" (metrics vs. reference images)
    sub-dicts per benchmark.
    """
    if chunk_start == 0 and chunk_step == 1:
        chunk_only = False
    # NOTE(review): this assert rejects any chunked call, which makes the
    # `if` above and the chunk-handling branches below unreachable for
    # chunk_start/chunk_step != (0, 1) — confirm whether chunked metric
    # computation is intended to be re-enabled later.
    assert chunk_start == 0 and chunk_step == 1, "Chunking is not supported for image data."
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if isinstance(benchmarks, str):
        benchmarks = (benchmarks,)
    # Split each metric list into (multimodal, similarity, other) buckets.
    gt_multimodal_metrics, gt_similarity_metrics, gt_other_metrics = categorize_metrics(gt_metrics)
    _, ref_similarity_metrics, ref_other_metrics = categorize_metrics(ref_metrics)
    results = {}
    for benchmark in benchmarks:
        benchmark_results = {}
        dataset = get_dataset(benchmark, max_dataset_size=max_dataset_size, return_gt=True)
        dirname = f"{dataset.config_name}-{dataset._unchunk_size}"
        # FID statistics cache path mirrors the chunking layout.
        if dataset._chunk_start == 0 and dataset._chunk_step == 1:
            filename = f"{dirname}.npz"
        else:
            filename = os.path.join(dirname, f"{dataset._chunk_start}-{dataset._chunk_step}.npz")
        if chunk_only:
            dirname += f".{dataset._chunk_start}.{dataset._chunk_step}"
        gen_dirpath = os.path.join(gen_root, "samples", benchmark, dirname)
        if gt_metrics:
            # Metrics against the ground-truth dataset (prompt/image pairs).
            gt_results = compute_image_multimodal_metrics(dataset, gen_dirpath, metrics=gt_multimodal_metrics)
            if "image_reward" in gt_other_metrics:
                gt_results.update(compute_image_reward(dataset, gen_dirpath))
            # Pixel-level similarity only makes sense when real images exist.
            if benchmark in ("COCO", "DCI", "MJHQ"):
                gt_results.update(compute_image_similarity_metrics(dataset, gen_dirpath, metrics=gt_similarity_metrics))
                if "fid" in gt_other_metrics:
                    gt_results["fid"] = compute_fid(
                        dataset,
                        gen_dirpath,
                        ref_cache_path=(os.path.join(gt_stats_root, benchmark, filename) if gt_stats_root else None),
                        gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
                    )
            benchmark_results["with_gt"] = gt_results
        if ref_root and ref_metrics:
            # Metrics against a previous (e.g. unquantized) generation run.
            assert os.path.exists(ref_root), f"Reference root directory {ref_root} does not exist."
            ref_dirpath = os.path.join(ref_root, "samples", benchmark, dirname)
            ref_results = compute_image_similarity_metrics(ref_dirpath, gen_dirpath, metrics=ref_similarity_metrics)
            if "fid" in ref_other_metrics:
                ref_results["fid"] = compute_fid(
                    ref_dirpath,
                    gen_dirpath,
                    ref_cache_path=os.path.join(ref_root, "fid_stats", benchmark, filename),
                    gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
                )
            benchmark_results["with_orig"] = ref_results
        print(f"{dirname} results:")
        print(benchmark_results)
        results[dirname] = benchmark_results
    return results
def categorize_metrics(metrics: tuple[str, ...]) -> tuple[list[str], list[str], list[str]]:
    """Split a collection of metric names into three buckets.

    Args:
        metrics (tuple[str, ...]): Metric names; duplicates are dropped.

    Returns:
        tuple[list[str], list[str], list[str]]:
            Multimodal metrics, similarity metrics, and all remaining metrics.
    """
    multimodal: list[str] = []
    similarity: list[str] = []
    remaining: list[str] = []
    for name in set(metrics):
        if name in ("clip_iqa", "clip_score"):
            multimodal.append(name)
        elif name in ("psnr", "lpips", "ssim"):
            similarity.append(name)
        else:
            remaining.append(name)
    return multimodal, similarity, remaining
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/fid.py
================================================
import os
from datetime import datetime
import numpy as np
import torch
import torchvision
from cleanfid import fid
from cleanfid.resize import build_resizer
from datasets import Dataset
from tqdm import tqdm
__all__ = ["compute_fid"]
def get_dataset_features(
    dataset: Dataset,
    model,
    mode: str = "clean",
    batch_size: int = 128,
    device: str | torch.device = "cuda",
) -> np.ndarray:
    """Extract clean-fid feature vectors for every image in a dataset.

    Args:
        dataset (Dataset): Dataset with an "image" column of PIL images.
        model: clean-fid feature extractor.
        mode (str): clean-fid resize mode. Defaults to "clean".
        batch_size (int): Number of images per feature-extraction batch.
        device (str | torch.device): Device used by the feature extractor.

    Returns:
        np.ndarray: Per-image feature vectors stacked along axis 0.
    """
    to_tensor = torchvision.transforms.ToTensor()
    fn_resize = build_resizer(mode)
    num_batches = (len(dataset) + batch_size - 1) // batch_size
    feature_chunks = []
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc=f"Extracting {dataset.config_name} features",
        total=num_batches,
    ):
        tensors = []
        for image in batch["image"]:
            resized = fn_resize(np.array(image.convert("RGB")))
            # ToTensor rescales uint8 inputs to [0, 1]; scale back to [0, 255]
            # so the feature extractor sees the range it expects.
            if resized.dtype == "uint8":
                tensors.append(to_tensor(resized) * 255)
            else:
                tensors.append(to_tensor(resized))
        stacked = torch.stack(tensors, dim=0)
        feature_chunks.append(fid.get_batch_features(stacked, model, device))
    return np.concatenate(feature_chunks, axis=0)
def get_fid_features(
    dataset_or_folder: str | Dataset | None = None,
    cache_path: str | None = None,
    num: int | None = None,
    mode: str = "clean",
    num_workers: int = 8,
    batch_size: int = 64,
    device: str | torch.device = "cuda",
    force_overwrite: bool = False,
    verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute (or load cached) FID feature statistics.

    Args:
        dataset_or_folder (str | Dataset | None): Image folder path or dataset.
        cache_path (str | None): Optional ``.npz`` file caching ``mu``/``sigma``.
        num (int | None): Optional cap on the number of folder images.
        mode (str): clean-fid resize mode.
        num_workers (int): Data-loading workers for folder extraction.
        batch_size (int): Feature-extraction batch size.
        device (str | torch.device): Device for the feature extractor.
        force_overwrite (bool): Recompute even when a cache file exists.
        verbose (bool): Verbosity flag forwarded to folder extraction.

    Returns:
        tuple[np.ndarray, np.ndarray]: Feature mean ``mu`` and covariance ``sigma``.
    """
    # Fast path: reuse previously computed statistics.
    if cache_path is not None and os.path.exists(cache_path) and not force_overwrite:
        cached = np.load(cache_path)
        return cached["mu"], cached["sigma"]
    feat_model = fid.build_feature_extractor(mode, device)
    if isinstance(dataset_or_folder, str):
        np_feats = fid.get_folder_features(
            dataset_or_folder,
            feat_model,
            num_workers=num_workers,
            num=num,
            batch_size=batch_size,
            device=device,
            verbose=verbose,
            mode=mode,
            description=f"Extracting {dataset_or_folder} features",
        )
    else:
        assert isinstance(dataset_or_folder, Dataset)
        np_feats = get_dataset_features(
            dataset_or_folder, model=feat_model, mode=mode, batch_size=batch_size, device=device
        )
    mu = np.mean(np_feats, axis=0)
    sigma = np.cov(np_feats, rowvar=False)
    if cache_path is not None:
        os.makedirs(os.path.abspath(os.path.dirname(cache_path)), exist_ok=True)
        np.savez(cache_path, mu=mu, sigma=sigma)
    return mu, sigma
def compute_fid(
    ref_dirpath_or_dataset: str | Dataset,
    gen_dirpath: str,
    ref_cache_path: str | None = None,
    gen_cache_path: str | None = None,
    use_symlink: bool = True,
    timestamp: str | None = None,
) -> float:
    """Compute the FID score between reference and generated images.

    Args:
        ref_dirpath_or_dataset (str | Dataset): Reference image folder or dataset.
        gen_dirpath (str): Folder of generated images.
        ref_cache_path (str | None): Optional ``.npz`` cache for reference stats.
        gen_cache_path (str | None): Optional ``.npz`` cache for generated stats.
        use_symlink (bool): Mirror folders via short symlinks under ``.tmp``
            before extraction (avoids issues with long/odd paths).
        timestamp (str | None): Suffix for the symlink names; defaults to now.

    Returns:
        float: The Frechet Inception Distance.
    """
    sym_ref_dirpath, sym_gen_dirpath = None, None
    # fix: wrap in try/finally so the temporary symlinks are removed even when
    # feature extraction or the FID computation raises (previously they leaked).
    try:
        if use_symlink:
            if timestamp is None:
                timestamp = datetime.now().strftime("%y%m%d.%H%M%S")
            os.makedirs(".tmp", exist_ok=True)
            if isinstance(ref_dirpath_or_dataset, str):
                sym_ref_dirpath = os.path.join(".tmp", f"ref-{hash(str(ref_dirpath_or_dataset))}-{timestamp}")
                os.symlink(os.path.abspath(ref_dirpath_or_dataset), os.path.abspath(sym_ref_dirpath))
                ref_dirpath_or_dataset = sym_ref_dirpath
            sym_gen_dirpath = os.path.join(".tmp", f"gen-{hash(str(gen_dirpath))}-{timestamp}")
            os.symlink(os.path.abspath(gen_dirpath), os.path.abspath(sym_gen_dirpath))
            gen_dirpath = sym_gen_dirpath
        mu1, sigma1 = get_fid_features(dataset_or_folder=ref_dirpath_or_dataset, cache_path=ref_cache_path)
        mu2, sigma2 = get_fid_features(dataset_or_folder=gen_dirpath, cache_path=gen_cache_path)
        return float(fid.frechet_distance(mu1, sigma1, mu2, sigma2))
    finally:
        # Only remove links we actually created.
        if sym_ref_dirpath is not None and os.path.islink(sym_ref_dirpath):
            os.remove(sym_ref_dirpath)
        if sym_gen_dirpath is not None and os.path.islink(sym_gen_dirpath):
            os.remove(sym_gen_dirpath)
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/image_reward.py
================================================
import os
import datasets
import torch
from tqdm import tqdm
__all__ = ["compute_image_reward"]
def compute_image_reward(
    ref_dataset: datasets.Dataset,
    gen_dirpath: str,
) -> dict[str, float]:
    """Score generated images with ImageReward against their prompts.

    Args:
        ref_dataset (datasets.Dataset): Dataset with "filename" and "prompt" columns.
        gen_dirpath (str): Directory of generated images named ``<filename>.png``.

    Returns:
        dict[str, float]: ``{"image_reward": mean ImageReward score}``.
    """
    # import here to remove dependency on `ImageReward` git repo
    import ImageReward as RM

    scores = []
    model = RM.load("ImageReward-v1.0")
    for batch in tqdm(
        ref_dataset.iter(batch_size=1, drop_last_batch=False),
        desc=f"{ref_dataset.config_name} image reward",
        total=len(ref_dataset),
        dynamic_ncols=True,
    ):
        filename = batch["filename"][0]
        # fix: score the image that matches this row's filename; the path was
        # previously built from a literal placeholder and never used `filename`.
        path = os.path.join(gen_dirpath, f"{filename}.png")
        prompt = batch["prompt"][0]
        with torch.inference_mode():
            score = model.score(prompt, path)
        scores.append(score)
    # Guard against an empty dataset to avoid ZeroDivisionError.
    result = {"image_reward": sum(scores) / len(scores) if scores else 0.0}
    return result
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/multimodal.py
================================================
import os
import datasets
import numpy as np
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.multimodal import CLIPImageQualityAssessment, CLIPScore
from tqdm import tqdm
__all__ = ["compute_image_multimodal_metrics"]
class PromptImageDataset(data.Dataset):
    """Pairs each generated image tensor with its reference prompt.

    Args:
        ref_dataset (datasets.Dataset): Dataset with "filename" and "prompt" columns.
        gen_dirpath (str): Directory of generated images named ``<filename>.png``.
    """

    def __init__(self, ref_dataset: datasets.Dataset, gen_dirpath: str):
        super(data.Dataset, self).__init__()
        self.ref_dataset = ref_dataset
        self.gen_dirpath = gen_dirpath
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.ref_dataset)

    def __getitem__(self, idx: int):
        row = self.ref_dataset[idx]
        image_path = os.path.join(self.gen_dirpath, row["filename"] + ".png")
        image = Image.open(image_path).convert("RGB")
        # uint8 HWC -> CHW tensor without the [0, 1] rescaling ToTensor would do.
        tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1)
        return [tensor, row["prompt"]]
def compute_image_multimodal_metrics(
    ref_dataset: datasets.Dataset,
    gen_dirpath: str,
    metrics: tuple[str, ...] = ("clip_iqa", "clip_score"),
    batch_size: int = 64,
    num_workers: int = 8,
    device: str | torch.device = "cuda",
) -> dict[str, float]:
    """Compute CLIP-based multimodal metrics on generated images.

    Args:
        ref_dataset (datasets.Dataset): Reference dataset providing prompts.
        gen_dirpath (str): Directory of generated images.
        metrics (tuple[str, ...]): Subset of {"clip_iqa", "clip_score"}.
        batch_size (int): Dataloader batch size.
        num_workers (int): Dataloader workers.
        device (str | torch.device): Device for metric computation.

    Returns:
        dict[str, float]: Mean value per requested metric.
    """
    if len(metrics) == 0:
        return {}
    evaluators: dict[str, torchmetrics.Metric] = {}
    for name in metrics:
        if name == "clip_iqa":
            evaluators[name] = CLIPImageQualityAssessment(model_name_or_path="openai/clip-vit-large-patch14").to(device)
        elif name == "clip_score":
            evaluators[name] = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
        else:
            raise NotImplementedError(f"Metric {name} is not implemented")
    loader = data.DataLoader(
        PromptImageDataset(ref_dataset, gen_dirpath),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )
    with torch.no_grad():
        for images, prompts in tqdm(loader, desc=f"{ref_dataset.config_name} multimodal metrics"):
            images = images.to(device)
            for name, evaluator in evaluators.items():
                if name == "clip_iqa":
                    # CLIP-IQA takes images only and expects float input.
                    evaluator.update(images.to(torch.float32))
                else:
                    evaluator.update(images, list(prompts))
    return {name: evaluator.compute().mean().item() for name, evaluator in evaluators.items()}
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/run.py
================================================
# -*- coding: utf-8 -*-
"""Evaluate generated images or videos using the specified metrics."""
import json
import os
from ...config import DiffusionPtqRunConfig
if __name__ == "__main__":
    # Parse the PTQ run configuration from the command line and fail fast on
    # any argument or configuration the parser did not consume.
    config, _, unused_cfgs, unused_args, unknown_args = DiffusionPtqRunConfig.get_parser().parse_known_args()
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"
    assert len(unused_cfgs) == 0, f"Unused configurations: {unused_cfgs}"
    assert unused_args is None, f"Unused arguments: {unused_args}"
    assert isinstance(config, DiffusionPtqRunConfig)
    # Evaluation-only run: skip generation and score previously generated samples.
    results = config.eval.evaluate(pipeline=None, skip_gen=True, task=config.pipeline.task)
    # Persist the metric results under the generation root, keyed by run timestamp.
    save_path = os.path.join(config.eval.gen_root, f"results-{config.output.timestamp}.json")
    os.makedirs(os.path.abspath(os.path.dirname(save_path)), exist_ok=True)
    with open(save_path, "w") as f:
        json.dump(results, f, indent=2, sort_keys=True)
    print(results)
================================================
FILE: deepcompressor/app/diffusion/eval/metrics/similarity.py
================================================
import os
import datasets
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import (
LearnedPerceptualImagePatchSimilarity,
PeakSignalNoiseRatio,
StructuralSimilarityIndexMeasure,
)
from tqdm import tqdm
__all__ = ["compute_image_similarity_metrics"]
class MultiImageDataset(data.Dataset):
    """Yields (generated, reference) image tensor pairs for similarity metrics.

    The reference side is either a directory of images — matched to the
    generated images by identical sorted filenames — or a dataset providing
    "image" and "filename" columns.
    """

    def __init__(self, gen_dirpath: str, ref_dirpath_or_dataset: str | datasets.Dataset):
        super(data.Dataset, self).__init__()
        self.gen_names = sorted(
            name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")
        )
        self.gen_dirpath = gen_dirpath
        self.ref_dirpath_or_dataset = ref_dirpath_or_dataset
        if isinstance(ref_dirpath_or_dataset, str):
            self.ref_names = sorted(
                name
                for name in os.listdir(ref_dirpath_or_dataset)
                if name.endswith(".png") or name.endswith(".jpg")
            )
            assert len(self.ref_names) == len(self.gen_names)
        else:
            assert isinstance(ref_dirpath_or_dataset, datasets.Dataset)
            self.ref_names = self.gen_names
            assert len(ref_dirpath_or_dataset) == len(self.gen_names)
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.ref_names)

    def __getitem__(self, idx: int):
        if isinstance(self.ref_dirpath_or_dataset, str):
            name = self.ref_names[idx]
            # Directory mode: the sorted listings must align index-by-index.
            assert name == self.gen_names[idx]
            ref_image = Image.open(os.path.join(self.ref_dirpath_or_dataset, name)).convert("RGB")
        else:
            row = self.ref_dirpath_or_dataset[idx]
            ref_image = row["image"].convert("RGB")
            name = row["filename"] + ".png"
        gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
        # Match resolutions so per-pixel metrics are well defined.
        if ref_image.size != gen_image.size:
            ref_image = ref_image.resize(gen_image.size, Image.Resampling.BICUBIC)
        return [self.transform(gen_image), self.transform(ref_image)]
def compute_image_similarity_metrics(
    ref_dirpath_or_dataset: str | datasets.Dataset,
    gen_dirpath: str,
    metrics: tuple[str, ...] = ("psnr", "lpips", "ssim"),
    batch_size: int = 64,
    num_workers: int = 8,
    device: str | torch.device = "cuda",
) -> dict[str, float]:
    """Compute pairwise similarity metrics between reference and generated images.

    Args:
        ref_dirpath_or_dataset (str | datasets.Dataset): Reference folder or dataset.
        gen_dirpath (str): Directory of generated images.
        metrics (tuple[str, ...]): Subset of {"psnr", "lpips", "ssim"}.
        batch_size (int): Dataloader batch size.
        num_workers (int): Dataloader workers.
        device (str | torch.device): Device for metric computation.

    Returns:
        dict[str, float]: Value per requested metric.
    """
    if len(metrics) == 0:
        return {}
    evaluators: dict[str, torchmetrics.Metric] = {}
    for name in metrics:
        if name == "psnr":
            evaluator = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to(device)
        elif name == "lpips":
            evaluator = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
        elif name == "ssim":
            evaluator = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to(device)
        else:
            raise NotImplementedError(f"Metric {name} is not implemented")
        evaluators[name] = evaluator
    loader = data.DataLoader(
        MultiImageDataset(gen_dirpath, ref_dirpath_or_dataset),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False,
    )
    if isinstance(ref_dirpath_or_dataset, datasets.Dataset):
        ref_label = ref_dirpath_or_dataset.config_name
    else:
        ref_label = os.path.basename(ref_dirpath_or_dataset)
    with torch.no_grad():
        for gen_batch, ref_batch in tqdm(loader, desc=ref_label + " similarity metrics"):
            gen_batch = gen_batch.to(device)
            ref_batch = ref_batch.to(device)
            for evaluator in evaluators.values():
                evaluator.update(gen_batch, ref_batch)
    return {name: evaluator.compute().item() for name, evaluator in evaluators.items()}
================================================
FILE: deepcompressor/app/diffusion/nn/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/app/diffusion/nn/attention.py
================================================
# -*- coding: utf-8 -*-
import typing as tp
import diffusers
import packaging.version
import torch
import torch.nn as nn
from diffusers.models.attention_processor import (
Attention,
AttnProcessor2_0,
FluxAttnProcessor2_0,
JointAttnProcessor2_0,
)
from deepcompressor.nn.patch.sdpa import ScaleDotProductAttention
__all__ = ["DiffusionAttentionProcessor"]
# `apply_rotary_emb` moved in diffusers 0.31; expose a uniform `apply_flux_rope`
# callable so the attention processor below works across diffusers versions.
if packaging.version.Version(diffusers.__version__) >= packaging.version.Version("0.31"):
    from diffusers.models.embeddings import apply_rotary_emb

    def apply_flux_rope(query, key, image_rotary_emb):
        # Apply rotary positional embeddings to both query and key.
        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)
        return query, key

else:
    from diffusers.models.attention_processor import apply_rope as apply_flux_rope
class DiffusionAttentionProcessor(nn.Module):
    """Drop-in replacement for diffusers attention processors.

    Wraps an original AttnProcessor2_0 / FluxAttnProcessor2_0 /
    JointAttnProcessor2_0 and routes the scaled-dot-product attention call
    through a patchable `ScaleDotProductAttention` module (e.g. for
    quantization or instrumentation).
    """

    def __init__(
        self,
        orig: AttnProcessor2_0 | FluxAttnProcessor2_0 | JointAttnProcessor2_0,
        sdpa: ScaleDotProductAttention | None = None,
    ) -> None:
        """Wrap `orig`; use `sdpa` if given, otherwise create a fresh SDPA module."""
        super().__init__()
        self.orig = orig
        # Flux processors apply rotary embeddings; the others do not.
        if orig.__class__.__name__.startswith("Flux"):
            self.rope = apply_flux_rope
        elif isinstance(orig, (AttnProcessor2_0, JointAttnProcessor2_0)):
            self.rope = None
        else:
            raise NotImplementedError(f"Unsupported AttentionProcessor: {orig}")
        self.sdpa = sdpa or ScaleDotProductAttention()

    def __call__(  # noqa: C901
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: tp.Optional[torch.Tensor] = None,
        attention_mask: tp.Optional[torch.Tensor] = None,
        image_rotary_emb: tp.Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Only plain configurations are supported: no spatial/group/cross norm,
        # no residual connection, no output rescaling, no extra args.
        assert len(args) == 0 and kwargs.get("scale", None) is None
        assert attn.spatial_norm is None
        assert attn.group_norm is None
        assert attn.norm_cross is None
        assert not attn.residual_connection
        assert attn.rescale_output_factor == 1.0
        heads = attn.heads
        head_dim = attn.inner_dim // heads
        kv_heads = attn.inner_kv_dim // head_dim
        assert attn.scale == head_dim**-0.5
        # Flatten >3D inputs (e.g. B, C, H, W) into (B, seq, C) sequence form.
        input_ndim, input_shape = hidden_states.dim(), hidden_states.size()
        if input_ndim > 3:
            hidden_states = hidden_states.view(input_shape[0], input_shape[1], -1).transpose(1, 2)
        batch_size, input_length, _ = hidden_states.shape
        context_ndim, context_shape, context_length = None, None, None
        if encoder_hidden_states is not None:
            context_ndim, context_shape = encoder_hidden_states.ndim, encoder_hidden_states.shape
            assert context_shape[0] == batch_size
            if context_ndim > 3:
                encoder_hidden_states = encoder_hidden_states.view(batch_size, context_shape[1], -1).transpose(1, 2)
            context_length = encoder_hidden_states.shape[1]
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, context_length or input_length, batch_size)
            attention_mask = attention_mask.view(batch_size, heads, -1, attention_mask.shape[-1])
        query = attn.to_q(hidden_states)
        key, value, add_query, add_key, add_value = None, None, None, None, None
        if hasattr(attn, "add_k_proj"):
            # Joint/Flux attention: extra projections for the encoder (text) stream.
            if attn.to_k is not None:
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
            add_key = attn.add_k_proj(encoder_hidden_states)
            add_value = attn.add_v_proj(encoder_hidden_states)
            if hasattr(attn, "add_q_proj"):
                add_query = attn.add_q_proj(encoder_hidden_states)
        else:
            if attn.is_cross_attention:
                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)
            else:
                assert encoder_hidden_states is None
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
        # Drop references so the names can be reused for the attention output.
        hidden_states, encoder_hidden_states = None, None
        # Reshape projections to (B, heads, seq, head_dim).
        query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        if key is not None:
            key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        if add_query is not None:
            add_query = add_query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        if add_key is not None:
            add_key = add_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
            add_value = add_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        # Grouped-query attention: expand key/value heads to match query heads.
        if kv_heads != heads:
            heads_per_kv_head = heads // kv_heads
            if key is not None:
                key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
                value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
            if add_key is not None:
                add_key = torch.repeat_interleave(add_key, heads_per_kv_head, dim=1)
                add_value = torch.repeat_interleave(add_value, heads_per_kv_head, dim=1)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
            key = attn.norm_k(key)
        if attn.norm_added_q is not None:
            add_query = attn.norm_added_q(add_query)
            add_key = attn.norm_added_k(add_key)
        # Prepend the encoder-stream projections along the sequence dimension.
        if add_query is not None:
            query = torch.cat([add_query, query], dim=2)
        if add_key is not None:
            if key is None:
                key, value = add_key, add_value
            else:
                key = torch.cat([add_key, key], dim=2)
                value = torch.cat([add_value, value], dim=2)
        del add_query, add_key, add_value
        if image_rotary_emb is not None:
            query, key = self.rope(query, key, image_rotary_emb)
        hidden_states = self.sdpa(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.inner_dim)
        hidden_states = hidden_states.to(query.dtype)
        # Joint attention returns a concatenated sequence: split the encoder
        # part (first `context_length` tokens) back out.
        if hidden_states.shape[1] > input_length:
            encoder_hidden_states = hidden_states[:, :context_length]
            hidden_states = hidden_states[:, context_length:]
        if hasattr(attn, "to_out"):
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
        # Restore the original >3D layouts.
        if input_ndim > 3:
            hidden_states = hidden_states.transpose(-1, -2).reshape(input_shape)
        if encoder_hidden_states is not None and context_ndim > 3:
            assert encoder_hidden_states.ndim == 3
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(context_shape)
        if encoder_hidden_states is None:
            return hidden_states
        return hidden_states, encoder_hidden_states
================================================
FILE: deepcompressor/app/diffusion/nn/patch.py
================================================
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock
from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.utils import patch, tools
from .attention import DiffusionAttentionProcessor
from .struct import DiffusionFeedForwardStruct, DiffusionModelStruct, DiffusionResnetStruct, UNetStruct
__all__ = [
"replace_up_block_conv_with_concat_conv",
"replace_fused_linear_with_concat_linear",
"replace_attn_processor",
"shift_input_activations",
]
def replace_up_block_conv_with_concat_conv(model: nn.Module) -> None:
    """Replace up_block convolutions in UNet with ConcatConv."""
    model_struct = DiffusionModelStruct.construct(model)
    # Only UNet-style models are handled; DiT models have no up blocks.
    if not isinstance(model_struct, UNetStruct):
        return
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing up_block convolutions with ConcatConv.")
    tools.logging.Formatter.indent_inc()
    parents_map = patch.get_module_parents_map(model)
    for up_block in model_struct.up_block_structs:
        logger.info(f"+ Replacing convolutions in up_block {up_block.name}")
        tools.logging.Formatter.indent_inc()
        for resnet in up_block.resnet_structs:
            assert len(resnet.convs[0]) == 1
            conv, conv_name = resnet.convs[0][0], resnet.conv_names[0][0]
            logger.info(f"- Replacing {conv_name} in resnet {resnet.name}")
            tools.logging.Formatter.indent_inc()
            # Determine the channel count produced by the previous block/resnet;
            # presumably this is the split point between the previous output and
            # the skip-connection part of the concatenated conv input — TODO confirm.
            if resnet.idx == 0:
                if up_block.idx == 0:
                    prev_block = model_struct.mid_block_struct
                else:
                    prev_block = model_struct.up_block_structs[up_block.idx - 1]
                logger.info(f"+ using previous block {prev_block.name}")
                prev_channels = prev_block.resnet_structs[-1].convs[-1][-1].out_channels
            else:
                prev_channels = up_block.resnet_structs[resnet.idx - 1].convs[-1][-1].out_channels
            logger.info(f"+ conv_in_channels = {prev_channels}/{conv.in_channels}")
            logger.info(f"+ conv_out_channels = {conv.out_channels}")
            # Split the conv's input channels at `prev_channels`.
            concat_conv = ConcatConv2d.from_conv2d(conv, [prev_channels])
            # Swap the module in every parent that references it.
            for parent_name, parent_module, child_name in parents_map[conv]:
                logger.info(f"+ replacing {child_name} in {parent_name}")
                setattr(parent_module, child_name, concat_conv)
            tools.logging.Formatter.indent_dec()
        tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()
def replace_fused_linear_with_concat_linear(model: nn.Module) -> None:
    """Replace fused Linear in FluxSingleTransformerBlock with ConcatLinear."""
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing fused Linear with ConcatLinear.")
    tools.logging.Formatter.indent_inc()
    for module_name, submodule in model.named_modules():
        if not isinstance(submodule, FluxSingleTransformerBlock):
            continue
        proj_out = submodule.proj_out
        logger.info(f"+ Replacing fused Linear in {module_name} with ConcatLinear.")
        tools.logging.Formatter.indent_inc()
        logger.info(f"- in_features = {proj_out.out_features}/{proj_out.in_features}")
        logger.info(f"- out_features = {proj_out.out_features}")
        tools.logging.Formatter.indent_dec()
        # Split proj_out's input features at out_features (attention vs. MLP part).
        submodule.proj_out = ConcatLinear.from_linear(proj_out, [proj_out.out_features])
    tools.logging.Formatter.indent_dec()
def shift_input_activations(model: nn.Module) -> None:
    """Shift input activations of convolutions and linear layers if their lowerbound is negative.

    Args:
        model (nn.Module): model to shift input activations.
    """
    logger = tools.logging.getLogger(__name__)
    model_struct = DiffusionModelStruct.construct(model)
    module_parents_map = patch.get_module_parents_map(model)
    logger.info("- Shifting input activations.")
    tools.logging.Formatter.indent_inc()
    for _, module_name, module, parent, field_name in model_struct.named_key_modules():
        # Only intermediate resnet convs and feed-forward down projections carry
        # a configured activation lowerbound.
        lowerbound = None
        if isinstance(parent, DiffusionResnetStruct) and field_name.startswith("conv"):
            lowerbound = parent.config.intermediate_lowerbound
        elif isinstance(parent, DiffusionFeedForwardStruct) and field_name.startswith("down_proj"):
            lowerbound = parent.config.intermediate_lowerbound
        if lowerbound is not None and lowerbound < 0:
            # Shift inputs so their range becomes non-negative, and mark the
            # wrapped layer unsigned (presumably for unsigned quantization —
            # the quantizer config is outside this file).
            shift = -lowerbound
            logger.info(f"+ Shifting input activations of {module_name} by {shift}")
            tools.logging.Formatter.indent_inc()
            if isinstance(module, nn.Linear):
                shifted = ShiftedLinear.from_linear(module, shift=shift)
                shifted.linear.unsigned = True
            elif isinstance(module, nn.Conv2d):
                shifted = ShiftedConv2d.from_conv2d(module, shift=shift)
                shifted.conv.unsigned = True
            else:
                raise NotImplementedError(f"Unsupported module type {type(module)}")
            # Swap the module in every parent that references it.
            for parent_name, parent_module, child_name in module_parents_map[module]:
                logger.info(f"+ Replacing {child_name} in {parent_name}")
                setattr(parent_module, child_name, shifted)
            tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()
def replace_attn_processor(model: nn.Module) -> None:
    """Replace Attention processor with DiffusionAttentionProcessor."""
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing Attention processors.")
    tools.logging.Formatter.indent_inc()
    for attn_name, attn_module in model.named_modules():
        if not isinstance(attn_module, Attention):
            continue
        logger.info(f"+ Replacing {attn_name} processor with DiffusionAttentionProcessor.")
        attn_module.set_processor(DiffusionAttentionProcessor(attn_module.processor))
    tools.logging.Formatter.indent_dec()
================================================
FILE: deepcompressor/app/diffusion/nn/struct.py
================================================
# -*- coding: utf-8 -*-
"""Utility functions for Diffusion Models."""
import enum
import typing as tp
from abc import abstractmethod
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field
# region imports
import torch.nn as nn
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU
from diffusers.models.attention import BasicTransformerBlock, FeedForward, JointTransformerBlock
from diffusers.models.attention_processor import Attention, SanaLinearAttnProcessor2_0
from diffusers.models.embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
)
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.sana_transformer import GLUMBConv, SanaTransformer2DModel, SanaTransformerBlock
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.models.transformers.transformer_flux import (
FluxSingleTransformerBlock,
FluxTransformer2DModel,
FluxTransformerBlock,
)
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.models.unets.unet_2d import UNet2DModel
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
UpBlock2D,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import (
FluxControlPipeline,
FluxFillPipeline,
FluxPipeline,
PixArtAlphaPipeline,
PixArtSigmaPipeline,
SanaPipeline,
StableDiffusion3Pipeline,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
)
from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.nn.struct.attn import (
AttentionConfigStruct,
AttentionStruct,
BaseTransformerStruct,
FeedForwardConfigStruct,
FeedForwardStruct,
TransformerBlockStruct,
)
from deepcompressor.nn.struct.base import BaseModuleStruct
from deepcompressor.utils.common import join_name
from .attention import DiffusionAttentionProcessor
# endregion
# fix: the list previously repeated "DiffusionModelStruct"; the third entry is
# presumably meant to export DiffusionModuleStruct (defined below) — confirm.
__all__ = ["DiffusionModelStruct", "DiffusionBlockStruct", "DiffusionModuleStruct"]
# Transformer (DiT) block classes recognized by the struct builders below.
DIT_BLOCK_CLS = tp.Union[
    BasicTransformerBlock,
    JointTransformerBlock,
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
    SanaTransformerBlock,
]
# UNet down/mid/up block classes.
UNET_BLOCK_CLS = tp.Union[
    DownBlock2D,
    CrossAttnDownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
    CrossAttnUpBlock2D,
]
# Supported transformer-backbone (DiT) denoiser model classes.
DIT_CLS = tp.Union[
    Transformer2DModel,
    PixArtTransformer2DModel,
    SD3Transformer2DModel,
    FluxTransformer2DModel,
    SanaTransformer2DModel,
]
# Supported UNet denoiser model classes.
UNET_CLS = tp.Union[UNet2DModel, UNet2DConditionModel]
MODEL_CLS = tp.Union[DIT_CLS, UNET_CLS]
# Pipelines whose denoiser lives in `.unet` vs. `.transformer`.
UNET_PIPELINE_CLS = tp.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
DIT_PIPELINE_CLS = tp.Union[
    StableDiffusion3Pipeline,
    PixArtAlphaPipeline,
    PixArtSigmaPipeline,
    FluxPipeline,
    FluxControlPipeline,
    FluxFillPipeline,
    SanaPipeline,
]
PIPELINE_CLS = tp.Union[UNET_PIPELINE_CLS, DIT_PIPELINE_CLS]
@dataclass(kw_only=True)
class DiffusionModuleStruct(BaseModuleStruct):
    """Struct wrapping a leaf (pre/post) module of a diffusion model."""

    def named_key_modules(self) -> tp.Generator[tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        """Yield (key, name, module, parent, field name) for every Linear/Conv2d."""
        if isinstance(self.module, (nn.Linear, nn.Conv2d)):
            yield self.key, self.name, self.module, self.parent, self.fname
            return
        for child_name, child in self.module.named_modules():
            if child_name and isinstance(child, (nn.Linear, nn.Conv2d)):
                yield (
                    self.key,
                    join_name(self.name, child_name),
                    child,
                    self.parent,
                    join_name(self.fname, child_name),
                )
@dataclass(kw_only=True)
class DiffusionBlockStruct(BaseModuleStruct):
    """Abstract interface shared by diffusion block and whole-model structs."""

    @abstractmethod
    def iter_attention_structs(self) -> tp.Generator["DiffusionAttentionStruct", None, None]: ...

    @abstractmethod
    def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]: ...
@dataclass(kw_only=True)
class DiffusionModelStruct(DiffusionBlockStruct):
    """Abstract struct over a whole diffusion model (pre-modules, blocks, post-modules)."""

    # Structs for the modules running before/after the main block stack;
    # populated by concrete subclasses (not via __init__).
    pre_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)
    post_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)

    @property
    @abstractmethod
    def num_blocks(self) -> int: ...

    @property
    @abstractmethod
    def block_structs(self) -> list[DiffusionBlockStruct]: ...

    @abstractmethod
    def get_prev_module_keys(self) -> tuple[str, ...]: ...

    @abstractmethod
    def get_post_module_keys(self) -> tuple[str, ...]: ...

    @abstractmethod
    def _get_iter_block_activations_args(
        self, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]: ...
def _get_iter_pre_module_activations_args(
    self,
) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
    """Build layer-iteration args for the pre-block modules.

    Returns:
        Layers, their structs, per-layer recompute flags (all False), and
        per-layer use-previous-layer-outputs flags (all False).
    """
    structs = list(self.pre_module_structs.values())
    layers = [struct.module for struct in structs]
    return layers, structs, [False] * len(structs), [False] * len(structs)
def _get_iter_post_module_activations_args(
    self,
) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
    """Build layer-iteration args for the post-block modules.

    Returns:
        Layers, their structs, per-layer recompute flags (all False), and
        per-layer use-previous-layer-outputs flags (all False).
    """
    structs = list(self.post_module_structs.values())
    layers = [struct.module for struct in structs]
    return layers, structs, [False] * len(structs), [False] * len(structs)
def get_iter_layer_activations_args(
    self, skip_pre_modules: bool, skip_post_modules: bool, **input_kwargs
) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
    """
    Get the arguments for iterating over the layers and their activations.

    Args:
        skip_pre_modules (`bool`):
            Whether to skip the pre-modules
        skip_post_modules (`bool`):
            Whether to skip the post-modules

    Returns:
        `tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]`:
            the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
    """
    # Start from the pre-module args (or empty lists when skipped), then
    # append the block args and, optionally, the post-module args.
    if skip_pre_modules:
        layers, structs, recomputes, uses = [], [], [], []
    else:
        layers, structs, recomputes, uses = self._get_iter_pre_module_activations_args()
    parts = [self._get_iter_block_activations_args(**input_kwargs)]
    if not skip_post_modules:
        parts.append(self._get_iter_post_module_activations_args())
    for part_layers, part_structs, part_recomputes, part_uses in parts:
        layers.extend(part_layers)
        structs.extend(part_structs)
        recomputes.extend(part_recomputes)
        uses.extend(part_uses)
    return layers, structs, recomputes, uses
def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
    """Yield key-module records from pre-modules, blocks, then post-modules."""
    for pre_struct in self.pre_module_structs.values():
        yield from pre_struct.named_key_modules()
    for block_struct in self.block_structs:
        yield from block_struct.named_key_modules()
    for post_struct in self.post_module_structs.values():
        yield from post_struct.named_key_modules()
def iter_attention_structs(self) -> tp.Generator["AttentionStruct", None, None]:
    """Iterate over every attention struct in every block."""
    for block_struct in self.block_structs:
        yield from block_struct.iter_attention_structs()
def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]:
    """Iterate over every transformer-block struct in every block."""
    for block_struct in self.block_structs:
        yield from block_struct.iter_transformer_block_structs()
def get_named_layers(
    self, skip_pre_modules: bool, skip_post_modules: bool, skip_blocks: bool = False
) -> OrderedDict[str, DiffusionBlockStruct | DiffusionModuleStruct]:
    """Collect the selected layer structs keyed by name, in forward order
    (pre-modules, blocks, post-modules)."""
    result: OrderedDict[str, DiffusionBlockStruct | DiffusionModuleStruct] = OrderedDict()
    if not skip_pre_modules:
        for name, struct in self.pre_module_structs.items():
            result[name] = struct
    if not skip_blocks:
        result.update((blk.name, blk) for blk in self.block_structs)
    if not skip_post_modules:
        for name, struct in self.post_module_structs.items():
            result[name] = struct
    return result
@staticmethod
def _default_construct(
    module: tp.Union[PIPELINE_CLS, MODEL_CLS],
    /,
    parent: tp.Optional[BaseModuleStruct] = None,
    fname: str = "",
    rname: str = "",
    rkey: str = "",
    idx: int = 0,
    **kwargs,
) -> "DiffusionModelStruct":
    """Construct the struct for a diffusion model, unwrapping pipelines to
    their backbone (UNet or DiT) first and dispatching on the backbone type."""
    # Pipelines carry the backbone under `.unet` or `.transformer`.
    if isinstance(module, UNET_PIPELINE_CLS):
        module = module.unet
    elif isinstance(module, DIT_PIPELINE_CLS):
        module = module.transformer
    # Dispatch to the matching struct class.
    for backbone_cls, struct_cls in ((UNET_CLS, UNetStruct), (DIT_CLS, DiTStruct)):
        if isinstance(module, backbone_cls):
            return struct_cls.construct(
                module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs
            )
    raise NotImplementedError(f"Unsupported module type: {type(module)}")
@classmethod
def _get_default_key_map(cls) -> dict[str, set[str]]:
    """Merge the default key maps of all supported architectures (UNet, DiT,
    Flux) into one mapping, dropping entries with no keys."""
    merged: dict[str, set[str]] = defaultdict(set)
    for sub_map in (
        UNetStruct._get_default_key_map(),
        DiTStruct._get_default_key_map(),
        FluxStruct._get_default_key_map(),
    ):
        for rkey, keys in sub_map.items():
            merged[rkey].update(keys)
    return {rkey: keys for rkey, keys in merged.items() if keys}
@staticmethod
def _simplify_keys(keys: tp.Iterable[str], *, key_map: dict[str, set[str]]) -> list[str]:
"""Simplify the keys based on the key map.
Args:
keys (`Iterable[str]`):
The keys to simplify.
key_map (`dict[str, set[str]]`):
The key map.
Returns:
`list[str]`:
The simplified keys.
"""
# we first sort key_map by length of values in descending order
key_map = dict(sorted(key_map.items(), key=lambda item: len(item[1]), reverse=True))
ukeys, skeys = set(keys), set()
for k, v in key_map.items():
if k in ukeys:
skeys.add(k)
ukeys.discard(k)
ukeys.difference_update(v)
continue
if ukeys.issuperset(v):
skeys.add(k)
ukeys.difference_update(v)
assert not ukeys, f"Unrecognized keys: {ukeys}"
return sorted(skeys)
@dataclass(kw_only=True)
class DiffusionAttentionStruct(AttentionStruct):
    """Struct describing a diffusers ``Attention`` module for quantization.

    Exposes the q/k/v/out projections, plus the ``add_*`` projections used by
    cross-attention and joint (dual-stream) blocks, in one uniform layout.
    """

    module: Attention = field(repr=False, kw_only=False)
    """the module of AttentionBlock"""
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)

    def filter_kwargs(self, kwargs: dict) -> dict:
        """Filter layer kwargs to attn kwargs.

        Args:
            kwargs (`dict`):
                The keyword arguments passed to the enclosing transformer block.

        Returns:
            `dict`:
                The subset of keyword arguments relevant to this attention module.
        """
        if isinstance(self.parent.module, BasicTransformerBlock):
            # start from the block-level cross-attention kwargs (copied so we can edit)
            if kwargs.get("cross_attention_kwargs", None) is None:
                attn_kwargs = {}
            else:
                attn_kwargs = dict(kwargs["cross_attention_kwargs"].items())
            # "gligen" is consumed by the block itself, not by the attention module
            attn_kwargs.pop("gligen", None)
            # attn at idx 0 is self-attention; later attns use the encoder mask
            if self.idx == 0:
                attn_kwargs["attention_mask"] = kwargs.get("attention_mask", None)
            else:
                attn_kwargs["attention_mask"] = kwargs.get("encoder_attention_mask", None)
        else:
            attn_kwargs = {}
        return attn_kwargs

    @staticmethod
    def _default_construct(
        module: Attention,
        /,
        parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionAttentionStruct":
        """Build a `DiffusionAttentionStruct` from a diffusers `Attention` module.

        Args:
            module (`Attention`):
                The attention module to wrap.
            parent (`DiffusionTransformerBlockStruct` or `None`, *optional*):
                The enclosing transformer block struct.
            fname (`str`, *optional*):
                The field name of this struct in its parent.
            rname (`str`, *optional*):
                The relative module name.
            rkey (`str`, *optional*):
                The relative key.
            idx (`int`, *optional*):
                The index of this attention within its parent.

        Returns:
            `DiffusionAttentionStruct`:
                The constructed struct.
        """
        if module.is_cross_attention:
            # cross-attention: k/v come from the encoder states, so to_k/to_v are
            # recorded as "add" projections while only to_q is a regular projection
            q_proj, k_proj, v_proj = module.to_q, None, None
            add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
        else:
            # self-attention (possibly joint): the add_* projections only exist on
            # dual-stream blocks, hence the getattr defaults of None
            q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
            add_q_proj = getattr(module, "add_q_proj", None)
            add_k_proj = getattr(module, "add_k_proj", None)
            add_v_proj = getattr(module, "add_v_proj", None)
            add_o_proj = getattr(module, "to_add_out", None)
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
            add_o_proj_rname = "to_add_out"
        if getattr(module, "to_out", None) is not None:
            o_proj = module.to_out[0]
            o_proj_rname = "to_out.0"
            assert isinstance(o_proj, nn.Linear)
        elif parent is not None:
            # FluxSingleTransformerBlock fuses attn-out and mlp-out into one
            # ConcatLinear `proj_out`; its first sub-linear is the attention output
            assert isinstance(parent.module, FluxSingleTransformerBlock)
            assert isinstance(parent.module.proj_out, ConcatLinear)
            assert len(parent.module.proj_out.linears) == 2
            o_proj = parent.module.proj_out.linears[0]
            # NOTE(review): the leading "." presumably marks this rname as relative to
            # the parent block rather than this attention module — confirm against
            # join_name's handling of leading dots
            o_proj_rname = ".proj_out.linears.0"
        else:
            raise RuntimeError("Cannot find the output projection.")
        if isinstance(module.processor, DiffusionAttentionProcessor):
            with_rope = module.processor.rope is not None
        elif module.processor.__class__.__name__.startswith("Flux"):
            # Flux processors always apply rotary embeddings
            with_rope = True
        else:
            with_rope = False  # TODO: fix for other processors
        config = AttentionConfigStruct(
            # nn.Linear weight is (out_features, in_features)
            hidden_size=q_proj.weight.shape[1],
            add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
            inner_size=q_proj.weight.shape[0],
            num_query_heads=module.heads,
            # head_dim inferred from to_q; to_k may carry fewer heads (GQA-style)
            num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
            with_qk_norm=module.norm_q is not None,
            with_rope=with_rope,
            linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
        )
        return DiffusionAttentionStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            o_proj=o_proj,
            add_q_proj=add_q_proj,
            add_k_proj=add_k_proj,
            add_v_proj=add_v_proj,
            add_o_proj=add_o_proj,
            q=None,  # TODO: add q, k, v
            k=None,
            v=None,
            q_proj_rname=q_proj_rname,
            k_proj_rname=k_proj_rname,
            v_proj_rname=v_proj_rname,
            o_proj_rname=o_proj_rname,
            add_q_proj_rname=add_q_proj_rname,
            add_k_proj_rname=add_k_proj_rname,
            add_v_proj_rname=add_v_proj_rname,
            add_o_proj_rname=add_o_proj_rname,
            q_rname="",
            k_rname="",
            v_rname="",
        )
@dataclass(kw_only=True)
class DiffusionFeedForwardStruct(FeedForwardStruct):
    """Struct describing a diffusion feed-forward network.

    Always holds exactly one expert (the module itself); the MoE-related fields
    exist only to satisfy the `FeedForwardStruct` interface.
    """

    module: FeedForward = field(repr=False, kw_only=False)
    """the module of FeedForward"""
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)
    # region modules
    # diffusion FFNs are not mixture-of-experts, so there is no gate
    moe_gate: None = field(init=False, repr=False, default=None)
    experts: list[nn.Module] = field(init=False, repr=False)
    # endregion
    # region names
    moe_gate_rname: str = field(init=False, repr=False, default="")
    experts_rname: str = field(init=False, repr=False, default="")
    # endregion
    # region aliases

    @property
    def up_proj(self) -> nn.Linear:
        # single-expert alias for the first (only) up projection
        return self.up_projs[0]

    @property
    def down_proj(self) -> nn.Linear:
        # single-expert alias for the first (only) down projection
        return self.down_projs[0]

    @property
    def up_proj_rname(self) -> str:
        return self.up_proj_rnames[0]

    @property
    def down_proj_rname(self) -> str:
        return self.down_proj_rnames[0]

    @property
    def up_proj_name(self) -> str:
        return self.up_proj_names[0]

    @property
    def down_proj_name(self) -> str:
        return self.down_proj_names[0]

    # endregion

    def __post_init__(self) -> None:
        # exactly one up/down projection pair is supported (single expert)
        assert len(self.up_projs) == len(self.down_projs) == 1
        assert len(self.up_proj_rnames) == len(self.down_proj_rnames) == 1
        self.experts = [self.module]
        super().__post_init__()

    @staticmethod
    def _default_construct(
        module: FeedForward | FluxSingleTransformerBlock | GLUMBConv,
        /,
        parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionFeedForwardStruct":
        """Build a `DiffusionFeedForwardStruct` from a supported FFN-like module.

        Args:
            module (`FeedForward`, `FluxSingleTransformerBlock`, or `GLUMBConv`):
                The module containing the feed-forward projections.
            parent (`DiffusionTransformerBlockStruct` or `None`, *optional*):
                The enclosing transformer block struct.
            fname (`str`, *optional*):
                The field name of this struct in its parent.
            rname (`str`, *optional*):
                The relative module name.
            rkey (`str`, *optional*):
                The relative key.
            idx (`int`, *optional*):
                The index of this struct within its parent.

        Returns:
            `DiffusionFeedForwardStruct`:
                The constructed struct.
        """
        if isinstance(module, FeedForward):
            # diffusers FeedForward: net[0] is the activation-wrapped up projection,
            # net[2] is the down projection
            layer_1, layer_2 = module.net[0], module.net[2]
            assert isinstance(layer_1, (GEGLU, GELU, ApproximateGELU, SwiGLU))
            up_proj, up_proj_rname = layer_1.proj, "net.0.proj"
            assert isinstance(up_proj, nn.Linear)
            down_proj, down_proj_rname = layer_2, "net.2"
            if isinstance(layer_1, GEGLU):
                act_type = "gelu_glu"
            elif isinstance(layer_1, SwiGLU):
                act_type = "swish_glu"
            else:
                assert layer_1.__class__.__name__.lower().endswith("gelu")
                act_type = "gelu"
            if isinstance(layer_2, ShiftedLinear):
                # unwrap the shift wrapper to reach the raw linear
                down_proj, down_proj_rname = layer_2.linear, "net.2.linear"
                act_type = "gelu_shifted"
            assert isinstance(down_proj, nn.Linear)
            ffn = module
        elif isinstance(module, FluxSingleTransformerBlock):
            # Flux single block fuses the FFN into the block: proj_mlp is the up
            # projection and the second sub-linear of ConcatLinear proj_out is the
            # down projection
            up_proj, up_proj_rname = module.proj_mlp, "proj_mlp"
            act_type = "gelu"
            assert isinstance(module.proj_out, ConcatLinear)
            assert len(module.proj_out.linears) == 2
            layer_2 = module.proj_out.linears[1]
            if isinstance(layer_2, ShiftedLinear):
                down_proj, down_proj_rname = layer_2.linear, "proj_out.linears.1.linear"
                act_type = "gelu_shifted"
            else:
                down_proj, down_proj_rname = layer_2, "proj_out.linears.1"
            # virtual module chaining the FFN path for hook registration
            ffn = nn.Sequential(up_proj, module.act_mlp, layer_2)
            assert not rname, f"Unsupported rname: {rname}"
        elif isinstance(module, GLUMBConv):
            # Sana's conv-based GLU feed-forward
            ffn = module
            up_proj, up_proj_rname = module.conv_inverted, "conv_inverted"
            down_proj, down_proj_rname = module.conv_point, "conv_point"
            act_type = "silu_conv_silu_glu"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        config = FeedForwardConfigStruct(
            # weight shape is (out, in): hidden size is the up projection's input,
            # intermediate size is the down projection's input
            hidden_size=up_proj.weight.shape[1],
            intermediate_size=down_proj.weight.shape[1],
            intermediate_act_type=act_type,
            num_experts=1,
        )
        return DiffusionFeedForwardStruct(
            module=ffn,  # this may be a virtual module
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            up_projs=[up_proj],
            down_projs=[down_proj],
            up_proj_rnames=[up_proj_rname],
            down_proj_rnames=[down_proj_rname],
        )
@dataclass(kw_only=True)
class DiffusionTransformerBlockStruct(TransformerBlockStruct, DiffusionBlockStruct):
    """Struct describing one diffusion transformer block.

    Supports the BasicTransformerBlock / SanaTransformerBlock /
    JointTransformerBlock / FluxTransformerBlock / FluxSingleTransformerBlock
    layouts from diffusers.
    """

    # region relative keys
    norm_rkey: tp.ClassVar[str] = "transformer_norm"
    add_norm_rkey: tp.ClassVar[str] = "transformer_add_norm"
    attn_struct_cls: tp.ClassVar[type[DiffusionAttentionStruct]] = DiffusionAttentionStruct
    ffn_struct_cls: tp.ClassVar[type[DiffusionFeedForwardStruct]] = DiffusionFeedForwardStruct
    # endregion
    parent: tp.Optional["DiffusionTransformerStruct"] = field(repr=False)
    # region child modules
    # post-norms are unused by diffusion blocks; kept to satisfy the base interface
    post_attn_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    post_ffn_norm: None = field(init=False, repr=False, default=None)
    post_add_ffn_norm: None = field(init=False, repr=False, default=None)
    # endregion
    # region relative names
    post_attn_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    post_ffn_norm_rname: str = field(init=False, repr=False, default="")
    post_add_ffn_norm_rname: str = field(init=False, repr=False, default="")
    # endregion
    # region attributes
    # norm_type describes the image-stream norms, add_norm_type the text stream
    norm_type: str
    add_norm_type: str
    # endregion
    # region absolute keys
    norm_key: str = field(init=False, repr=False)
    add_norm_key: str = field(init=False, repr=False)
    # endregion
    # region child structs
    # NOTE(review): these two list fields are declared but `__post_init__` assigns
    # `attn_norm_structs` / `add_attn_norm_structs` instead — confirm which names
    # the base class and downstream code actually read
    pre_attn_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
    pre_attn_add_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
    pre_ffn_norm_struct: DiffusionModuleStruct = field(init=False, repr=False, default=None)
    pre_add_ffn_norm_struct: DiffusionModuleStruct | None = field(init=False, repr=False, default=None)
    attn_structs: list[DiffusionAttentionStruct] = field(init=False, repr=False)
    ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
    add_ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
    # endregion

    def __post_init__(self) -> None:
        """Derive absolute keys and wrap every norm module into a child struct."""
        super().__post_init__()
        self.norm_key = join_name(self.key, self.norm_rkey, sep="_")
        self.add_norm_key = join_name(self.key, self.add_norm_rkey, sep="_")
        self.attn_norm_structs = [
            DiffusionModuleStruct(norm, parent=self, fname="pre_attn_norm", rname=rname, rkey=self.norm_rkey, idx=idx)
            for idx, (norm, rname) in enumerate(zip(self.pre_attn_norms, self.pre_attn_norm_rnames, strict=True))
        ]
        self.add_attn_norm_structs = [
            DiffusionModuleStruct(
                norm, parent=self, fname="pre_attn_add_norm", rname=rname, rkey=self.add_norm_rkey, idx=idx
            )
            for idx, (norm, rname) in enumerate(
                zip(self.pre_attn_add_norms, self.pre_attn_add_norm_rnames, strict=True)
            )
        ]
        if self.pre_ffn_norm is not None:
            self.pre_ffn_norm_struct = DiffusionModuleStruct(
                self.pre_ffn_norm, parent=self, fname="pre_ffn_norm", rname=self.pre_ffn_norm_rname, rkey=self.norm_rkey
            )
        if self.pre_add_ffn_norm is not None:
            self.pre_add_ffn_norm_struct = DiffusionModuleStruct(
                self.pre_add_ffn_norm,
                parent=self,
                fname="pre_add_ffn_norm",
                rname=self.pre_add_ffn_norm_rname,
                rkey=self.add_norm_rkey,
            )

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        """Yield ``(key, name, module, struct, fname)`` for the norms, attentions,
        and FFNs of this block, skipping norms that are shared with an attention."""
        for attn_norm in self.attn_norm_structs:
            if attn_norm.module is not None:
                yield from attn_norm.named_key_modules()
        for add_attn_norm in self.add_attn_norm_structs:
            if add_attn_norm.module is not None:
                yield from add_attn_norm.named_key_modules()
        for attn_struct in self.attn_structs:
            yield from attn_struct.named_key_modules()
        if self.pre_ffn_norm_struct is not None:
            # avoid double-yield when the FFN norm is the same module as the first
            # attention norm (e.g. Flux single blocks share one norm)
            if self.pre_attn_norms and self.pre_attn_norms[0] is not self.pre_ffn_norm:
                yield from self.pre_ffn_norm_struct.named_key_modules()
        if self.ffn_struct is not None:
            yield from self.ffn_struct.named_key_modules()
        if self.pre_add_ffn_norm_struct is not None:
            if self.pre_attn_add_norms and self.pre_attn_add_norms[0] is not self.pre_add_ffn_norm:
                yield from self.pre_add_ffn_norm_struct.named_key_modules()
        if self.add_ffn_struct is not None:
            yield from self.add_ffn_struct.named_key_modules()

    @staticmethod
    def _default_construct(
        module: DIT_BLOCK_CLS,
        /,
        parent: tp.Optional["DiffusionTransformerStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionTransformerBlockStruct":
        """Map a concrete diffusers block onto the uniform struct layout.

        Each branch identifies, for one block class, the pre-attention norms, the
        attention modules, the pre-FFN norms, and the FFNs of both the image
        stream and (where present) the text ("add") stream.
        """
        if isinstance(module, (BasicTransformerBlock, SanaTransformerBlock)):
            parallel = False
            if isinstance(module, SanaTransformerBlock):
                # Sana blocks do not expose norm_type; they use single AdaLN
                norm_type = add_norm_type = "ada_norm_single"
            else:
                norm_type = add_norm_type = module.norm_type
            pre_attn_norms, pre_attn_norm_rnames = [], []
            attns, attn_rnames = [], []
            pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
            assert module.norm1 is not None
            assert module.attn1 is not None
            pre_attn_norms.append(module.norm1)
            pre_attn_norm_rnames.append("norm1")
            attns.append(module.attn1)
            attn_rnames.append("attn1")
            pre_attn_add_norms.append(module.attn1.norm_cross)
            pre_attn_add_norm_rnames.append("attn1.norm_cross")
            if module.attn2 is not None:
                if norm_type == "ada_norm_single":
                    # with single AdaLN, norm2 belongs to the FFN (see below),
                    # so attn2 has no dedicated pre-norm
                    pre_attn_norms.append(None)
                    pre_attn_norm_rnames.append("")
                else:
                    assert module.norm2 is not None
                    pre_attn_norms.append(module.norm2)
                    pre_attn_norm_rnames.append("norm2")
                attns.append(module.attn2)
                attn_rnames.append("attn2")
                pre_attn_add_norms.append(module.attn2.norm_cross)
                pre_attn_add_norm_rnames.append("attn2.norm_cross")
            if norm_type == "ada_norm_single":
                assert module.norm2 is not None
                pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            else:
                pre_ffn_norm, pre_ffn_norm_rname = module.norm3, "" if module.norm3 is None else "norm3"
            ffn, ffn_rname = module.ff, "" if module.ff is None else "ff"
            # single-stream blocks have no text-stream norm/FFN
            pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
        elif isinstance(module, JointTransformerBlock):
            parallel = False
            norm_type = "ada_norm_zero"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
            if isinstance(module.norm1_context, AdaLayerNormZero):
                add_norm_type = "ada_norm_zero"
            else:
                add_norm_type = "ada_norm_continous"
            pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            ffn, ffn_rname = module.ff, "ff"
            pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
            add_ffn, add_ffn_rname = module.ff_context, "ff_context"
        elif isinstance(module, FluxSingleTransformerBlock):
            # attention and FFN run in parallel off the same norm; the FFN struct
            # is built from the block itself (empty rname)
            parallel = True
            norm_type = add_norm_type = "ada_norm_zero_single"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm], ["norm"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
            pre_ffn_norm, pre_ffn_norm_rname = module.norm, "norm"
            ffn, ffn_rname = module, ""
            pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
        elif isinstance(module, FluxTransformerBlock):
            parallel = False
            norm_type = add_norm_type = "ada_norm_zero"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
            pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            ffn, ffn_rname = module.ff, "ff"
            pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
            add_ffn, add_ffn_rname = module.ff_context, "ff_context"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        return DiffusionTransformerBlockStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            parallel=parallel,
            pre_attn_norms=pre_attn_norms,
            pre_attn_add_norms=pre_attn_add_norms,
            attns=attns,
            pre_ffn_norm=pre_ffn_norm,
            ffn=ffn,
            pre_add_ffn_norm=pre_add_ffn_norm,
            add_ffn=add_ffn,
            pre_attn_norm_rnames=pre_attn_norm_rnames,
            pre_attn_add_norm_rnames=pre_attn_add_norm_rnames,
            attn_rnames=attn_rnames,
            pre_ffn_norm_rname=pre_ffn_norm_rname,
            ffn_rname=ffn_rname,
            pre_add_ffn_norm_rname=pre_add_ffn_norm_rname,
            add_ffn_rname=add_ffn_rname,
            norm_type=norm_type,
            add_norm_type=add_norm_type,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        # Maps every addressable relative key of this block (norms, attention
        # projections, FFN projections, and their "add"-stream variants) to the
        # set of concrete keys it expands to.
        key_map: dict[str, set[str]] = defaultdict(set)
        norm_rkey = norm_key = cls.norm_rkey
        add_norm_rkey = add_norm_key = cls.add_norm_rkey
        key_map[norm_rkey].add(norm_key)
        key_map[add_norm_rkey].add(add_norm_key)
        attn_cls = cls.attn_struct_cls
        attn_key = attn_rkey = cls.attn_rkey
        qkv_proj_key = qkv_proj_rkey = join_name(attn_key, attn_cls.qkv_proj_rkey, sep="_")
        out_proj_key = out_proj_rkey = join_name(attn_key, attn_cls.out_proj_rkey, sep="_")
        add_qkv_proj_key = add_qkv_proj_rkey = join_name(attn_key, attn_cls.add_qkv_proj_rkey, sep="_")
        add_out_proj_key = add_out_proj_rkey = join_name(attn_key, attn_cls.add_out_proj_rkey, sep="_")
        # the bare attention key expands to its qkv and out projections
        key_map[attn_rkey].add(qkv_proj_key)
        key_map[attn_rkey].add(out_proj_key)
        if attn_cls.add_qkv_proj_rkey.startswith("add_") and attn_cls.add_out_proj_rkey.startswith("add_"):
            # grouped key covering both "add"-stream attention projections
            add_attn_rkey = join_name(attn_rkey, "add", sep="_")
            key_map[add_attn_rkey].add(add_qkv_proj_key)
            key_map[add_attn_rkey].add(add_out_proj_key)
        key_map[qkv_proj_rkey].add(qkv_proj_key)
        key_map[out_proj_rkey].add(out_proj_key)
        key_map[add_qkv_proj_rkey].add(add_qkv_proj_key)
        key_map[add_out_proj_rkey].add(add_out_proj_key)
        ffn_cls = cls.ffn_struct_cls
        ffn_key = ffn_rkey = cls.ffn_rkey
        add_ffn_key = add_ffn_rkey = cls.add_ffn_rkey
        up_proj_key = up_proj_rkey = join_name(ffn_key, ffn_cls.up_proj_rkey, sep="_")
        down_proj_key = down_proj_rkey = join_name(ffn_key, ffn_cls.down_proj_rkey, sep="_")
        add_up_proj_key = add_up_proj_rkey = join_name(add_ffn_key, ffn_cls.up_proj_rkey, sep="_")
        add_down_proj_key = add_down_proj_rkey = join_name(add_ffn_key, ffn_cls.down_proj_rkey, sep="_")
        # the bare FFN keys expand to their up/down projections
        key_map[ffn_rkey].add(up_proj_key)
        key_map[ffn_rkey].add(down_proj_key)
        key_map[add_ffn_rkey].add(add_up_proj_key)
        key_map[add_ffn_rkey].add(add_down_proj_key)
        key_map[up_proj_rkey].add(up_proj_key)
        key_map[down_proj_rkey].add(down_proj_key)
        key_map[add_up_proj_rkey].add(add_up_proj_key)
        key_map[add_down_proj_rkey].add(add_down_proj_key)
        return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class DiffusionTransformerStruct(BaseTransformerStruct, DiffusionBlockStruct):
    """Struct describing a diffusers ``Transformer2DModel``: input/output
    projections plus an ordered list of transformer blocks."""

    # region relative keys
    proj_in_rkey: tp.ClassVar[str] = "transformer_proj_in"
    proj_out_rkey: tp.ClassVar[str] = "transformer_proj_out"
    transformer_block_rkey: tp.ClassVar[str] = ""
    transformer_block_struct_cls: tp.ClassVar[type[DiffusionTransformerBlockStruct]] = DiffusionTransformerBlockStruct
    # endregion
    module: Transformer2DModel = field(repr=False, kw_only=False)
    # region modules
    norm_in: nn.GroupNorm | None
    """Input normalization"""
    proj_in: nn.Linear | nn.Conv2d
    """Input projection"""
    norm_out: nn.GroupNorm | None
    """Output normalization"""
    proj_out: nn.Linear | nn.Conv2d
    """Output projection"""
    transformer_blocks: nn.ModuleList = field(repr=False)
    """Transformer blocks"""
    # endregion
    # region relative names
    transformer_blocks_rname: str
    # endregion
    # region absolute names
    transformer_blocks_name: str = field(init=False, repr=False)
    transformer_block_names: list[str] = field(init=False, repr=False)
    # endregion
    # region child structs
    transformer_block_structs: list[DiffusionTransformerBlockStruct] = field(init=False, repr=False)
    # endregion
    # region aliases

    @property
    def num_blocks(self) -> int:
        return len(self.transformer_blocks)

    @property
    def block_structs(self) -> list[DiffusionBlockStruct]:
        return self.transformer_block_structs

    @property
    def block_names(self) -> list[str]:
        return self.transformer_block_names

    # endregion

    def __post_init__(self):
        """Derive per-block names and build a struct for every transformer block."""
        super().__post_init__()
        transformer_block_rnames = [
            f"{self.transformer_blocks_rname}.{idx}" for idx in range(len(self.transformer_blocks))
        ]
        self.transformer_blocks_name = join_name(self.name, self.transformer_blocks_rname)
        self.transformer_block_names = [join_name(self.name, rname) for rname in transformer_block_rnames]
        self.transformer_block_structs = [
            self.transformer_block_struct_cls.construct(
                layer,
                parent=self,
                fname="transformer_block",
                rname=rname,
                rkey=self.transformer_block_rkey,
                idx=idx,
            )
            for idx, (layer, rname) in enumerate(zip(self.transformer_blocks, transformer_block_rnames, strict=True))
        ]

    @staticmethod
    def _default_construct(
        module: Transformer2DModel,
        /,
        parent: BaseModuleStruct = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionTransformerStruct":
        """Build a `DiffusionTransformerStruct` from a `Transformer2DModel`."""
        if isinstance(module, Transformer2DModel):
            # only the continuous-input variant (GroupNorm + proj_in) is supported
            assert module.is_input_continuous, "input must be continuous"
            transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
            norm_in, norm_in_rname = module.norm, "norm"
            proj_in, proj_in_rname = module.proj_in, "proj_in"
            proj_out, proj_out_rname = module.proj_out, "proj_out"
            norm_out, norm_out_rname = None, ""
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        return DiffusionTransformerStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            norm_in=norm_in,
            proj_in=proj_in,
            transformer_blocks=transformer_blocks,
            proj_out=proj_out,
            norm_out=norm_out,
            norm_in_rname=norm_in_rname,
            proj_in_rname=proj_in_rname,
            transformer_blocks_rname=transformer_blocks_rname,
            norm_out_rname=norm_out_rname,
            proj_out_rname=proj_out_rname,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        proj_in_rkey = proj_in_key = cls.proj_in_rkey
        proj_out_rkey = proj_out_key = cls.proj_out_rkey
        key_map[proj_in_rkey].add(proj_in_key)
        key_map[proj_out_rkey].add(proj_out_key)
        # prefix every block-level key with this transformer's block key
        block_cls = cls.transformer_block_struct_cls
        block_key = block_rkey = cls.transformer_block_rkey
        block_key_map = block_cls._get_default_key_map()
        for rkey, keys in block_key_map.items():
            rkey = join_name(block_rkey, rkey, sep="_")
            for key in keys:
                key = join_name(block_key, key, sep="_")
                key_map[rkey].add(key)
        return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class DiffusionResnetStruct(BaseModuleStruct):
    """Struct describing a diffusers ``ResnetBlock2D``: its two conv stages
    (each possibly split/shifted), the optional shortcut conv, and the optional
    time-embedding projection."""

    # region relative keys
    conv_rkey: tp.ClassVar[str] = "conv"
    shortcut_rkey: tp.ClassVar[str] = "shortcut"
    time_proj_rkey: tp.ClassVar[str] = "time_proj"
    # endregion
    module: ResnetBlock2D = field(repr=False, kw_only=False)
    """the module of Resnet"""
    config: FeedForwardConfigStruct
    # region child modules
    norms: list[nn.GroupNorm]
    # one inner list of convs per stage (conv1, conv2); a stage may hold several
    # convs when the original conv was a ConcatConv2d
    convs: list[list[nn.Conv2d]]
    shortcut: nn.Conv2d | None
    time_proj: nn.Linear | None
    # endregion
    # region relative names
    norm_rnames: list[str]
    conv_rnames: list[list[str]]
    shortcut_rname: str
    time_proj_rname: str
    # endregion
    # region absolute names
    norm_names: list[str] = field(init=False, repr=False)
    conv_names: list[list[str]] = field(init=False, repr=False)
    shortcut_name: str = field(init=False, repr=False)
    time_proj_name: str = field(init=False, repr=False)
    # endregion
    # region absolute keys
    conv_key: str = field(init=False, repr=False)
    shortcut_key: str = field(init=False, repr=False)
    time_proj_key: str = field(init=False, repr=False)
    # endregion

    def __post_init__(self):
        """Derive absolute names and keys from the relative ones."""
        super().__post_init__()
        self.norm_names = [join_name(self.name, rname) for rname in self.norm_rnames]
        self.conv_names = [[join_name(self.name, rname) for rname in rnames] for rnames in self.conv_rnames]
        self.shortcut_name = join_name(self.name, self.shortcut_rname)
        self.time_proj_name = join_name(self.name, self.time_proj_rname)
        self.conv_key = join_name(self.key, self.conv_rkey, sep="_")
        self.shortcut_key = join_name(self.key, self.shortcut_rkey, sep="_")
        self.time_proj_key = join_name(self.key, self.time_proj_rkey, sep="_")

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        """Yield ``(key, name, module, struct, fname)`` for every conv, the
        shortcut (if any), and the time projection (if any)."""
        for convs, names in zip(self.convs, self.conv_names, strict=True):
            for conv, name in zip(convs, names, strict=True):
                yield self.conv_key, name, conv, self, "conv"
        if self.shortcut is not None:
            yield self.shortcut_key, self.shortcut_name, self.shortcut, self, "shortcut"
        if self.time_proj is not None:
            yield self.time_proj_key, self.time_proj_name, self.time_proj, self, "time_proj"

    @staticmethod
    def construct(
        module: ResnetBlock2D,
        /,
        parent: BaseModuleStruct = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionResnetStruct":
        """Build a `DiffusionResnetStruct` from a `ResnetBlock2D`.

        Each conv stage is unwrapped down to its raw `nn.Conv2d` leaves:
        `ConcatConv2d` contributes one leaf per sub-conv, and `ShiftedConv2d`
        wrappers are peeled (marking the activation type as shifted).
        """
        if isinstance(module, ResnetBlock2D):
            assert module.upsample is None, "upsample must be None"
            assert module.downsample is None, "downsample must be None"
            act_type = module.nonlinearity.__class__.__name__.lower()
            shifted = False
            # --- unwrap conv1 ---
            if isinstance(module.conv1, ConcatConv2d):
                conv1_convs, conv1_names = [], []
                for conv_idx, conv in enumerate(module.conv1.convs):
                    if isinstance(conv, ShiftedConv2d):
                        shifted = True
                        conv1_convs.append(conv.conv)
                        conv1_names.append(f"conv1.convs.{conv_idx}.conv")
                    else:
                        assert isinstance(conv, nn.Conv2d)
                        conv1_convs.append(conv)
                        conv1_names.append(f"conv1.convs.{conv_idx}")
            elif isinstance(module.conv1, ShiftedConv2d):
                shifted = True
                conv1_convs = [module.conv1.conv]
                conv1_names = ["conv1.conv"]
            else:
                assert isinstance(module.conv1, nn.Conv2d)
                conv1_convs, conv1_names = [module.conv1], ["conv1"]
            # --- unwrap conv2 (same cases as conv1) ---
            if isinstance(module.conv2, ConcatConv2d):
                conv2_convs, conv2_names = [], []
                for conv_idx, conv in enumerate(module.conv2.convs):
                    if isinstance(conv, ShiftedConv2d):
                        shifted = True
                        conv2_convs.append(conv.conv)
                        conv2_names.append(f"conv2.convs.{conv_idx}.conv")
                    else:
                        assert isinstance(conv, nn.Conv2d)
                        conv2_convs.append(conv)
                        conv2_names.append(f"conv2.convs.{conv_idx}")
            elif isinstance(module.conv2, ShiftedConv2d):
                shifted = True
                conv2_convs = [module.conv2.conv]
                conv2_names = ["conv2.conv"]
            else:
                assert isinstance(module.conv2, nn.Conv2d)
                conv2_convs, conv2_names = [module.conv2], ["conv2"]
            convs, conv_rnames = [conv1_convs, conv2_convs], [conv1_names, conv2_names]
            norms, norm_rnames = [module.norm1, module.norm2], ["norm1", "norm2"]
            shortcut, shortcut_rname = module.conv_shortcut, "" if module.conv_shortcut is None else "conv_shortcut"
            time_proj, time_proj_rname = module.time_emb_proj, "" if module.time_emb_proj is None else "time_emb_proj"
            if shifted:
                # if any conv was shifted, all of them must be
                assert all(hasattr(conv, "shifted") and conv.shifted for level_convs in convs for conv in level_convs)
                act_type += "_shifted"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        config = FeedForwardConfigStruct(
            # Conv2d weight is (out_ch, in_ch / groups, kH, kW)
            hidden_size=convs[0][0].weight.shape[1],
            intermediate_size=convs[0][0].weight.shape[0],
            intermediate_act_type=act_type,
            num_experts=1,
        )
        return DiffusionResnetStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            norms=norms,
            convs=convs,
            shortcut=shortcut,
            time_proj=time_proj,
            norm_rnames=norm_rnames,
            conv_rnames=conv_rnames,
            shortcut_rname=shortcut_rname,
            time_proj_rname=time_proj_rname,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        conv_key = conv_rkey = cls.conv_rkey
        shortcut_key = shortcut_rkey = cls.shortcut_rkey
        time_proj_key = time_proj_rkey = cls.time_proj_rkey
        key_map[conv_rkey].add(conv_key)
        key_map[shortcut_rkey].add(shortcut_key)
        key_map[time_proj_rkey].add(time_proj_key)
        return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class UNetBlockStruct(DiffusionBlockStruct):
class BlockType(enum.StrEnum):
DOWN = "down"
MID = "mid"
UP = "up"
# region relative keys
resnet_rkey: tp.ClassVar[str] = "resblock"
sampler_rkey: tp.ClassVar[str] = "sample"
transformer_rkey: tp.ClassVar[str] = ""
resnet_struct_cls: tp.ClassVar[type[DiffusionResnetStruct]] = DiffusionResnetStruct
transformer_struct_cls: tp.ClassVar[type[DiffusionTransformerStruct]] = DiffusionTransformerStruct
# endregion
parent: tp.Optional["UNetStruct"] = field(repr=False)
# region attributes
block_type: BlockType
# endregion
# region modules
resnets: nn.ModuleList = field(repr=False)
transformers: nn.ModuleList = field(repr=False)
sampler: nn.Conv2d | None
# endregion
# region relative names
resnets_rname: str
transformers_rname: str
sampler_rname: str
# endregion
# region absolute names
resnets_name: str = field(init=False, repr=False)
transformers_name: str = field(init=False, repr=False)
sampler_name: str = field(init=False, repr=False)
resnet_names: list[str] = field(init=False, repr=False)
transformer_names: list[str] = field(init=False, repr=False)
# endregion
# region absolute keys
sampler_key: str = field(init=False, repr=False)
# endregion
# region child structs
resnet_structs: list[DiffusionResnetStruct] = field(init=False, repr=False)
transformer_structs: list[DiffusionTransformerStruct] = field(init=False, repr=False)
# endregion
@property
def downsample(self) -> nn.Conv2d | None:
    """The downsampling conv of this block, or ``None`` for mid/up blocks."""
    if self.is_downsample_block():
        return self.sampler
    return None
@property
def upsample(self) -> nn.Conv2d | None:
    """The upsampling conv of this block, or ``None`` for down/mid blocks."""
    if self.is_upsample_block():
        return self.sampler
    return None
def __post_init__(self) -> None:
    """Validate the block layout against its position in the UNet, derive
    absolute names/keys, and build child structs for resnets and transformers."""
    super().__post_init__()
    if self.is_downsample_block():
        # down blocks pair each resnet with a transformer, or have no transformers
        assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
        if self.parent is not None and isinstance(self.parent, UNetStruct):
            assert self.rname == f"{self.parent.down_blocks_rname}.{self.idx}"
    elif self.is_mid_block():
        # the mid block interleaves resnet-transformer-...-resnet, hence the +1
        assert len(self.resnets) == len(self.transformers) + 1 or len(self.transformers) == 0
        if self.parent is not None and isinstance(self.parent, UNetStruct):
            # NOTE(review): this compares rname against parent.mid_block_name while
            # the down/up branches compare against *_rname — confirm intended
            assert self.rname == self.parent.mid_block_name
            assert self.idx == 0
    else:
        assert self.is_upsample_block(), f"Unsupported block type: {self.block_type}"
        assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
        if self.parent is not None and isinstance(self.parent, UNetStruct):
            assert self.rname == f"{self.parent.up_blocks_rname}.{self.idx}"
    # per-child relative names, then absolutized against this block's name
    resnet_rnames = [f"{self.resnets_rname}.{idx}" for idx in range(len(self.resnets))]
    transformer_rnames = [f"{self.transformers_rname}.{idx}" for idx in range(len(self.transformers))]
    self.resnets_name = join_name(self.name, self.resnets_rname)
    self.transformers_name = join_name(self.name, self.transformers_rname)
    self.resnet_names = [join_name(self.name, rname) for rname in resnet_rnames]
    self.transformer_names = [join_name(self.name, rname) for rname in transformer_rnames]
    self.sampler_name = join_name(self.name, self.sampler_rname)
    self.sampler_key = join_name(self.key, self.sampler_rkey, sep="_")
    self.resnet_structs = [
        self.resnet_struct_cls.construct(
            resnet, parent=self, fname="resnet", rname=rname, rkey=self.resnet_rkey, idx=idx
        )
        for idx, (resnet, rname) in enumerate(zip(self.resnets, resnet_rnames, strict=True))
    ]
    self.transformer_structs = [
        self.transformer_struct_cls.construct(
            transformer, parent=self, fname="transformer", rname=rname, rkey=self.transformer_rkey, idx=idx
        )
        for idx, (transformer, rname) in enumerate(zip(self.transformers, transformer_rnames, strict=True))
    ]
def is_downsample_block(self) -> bool:
    """Tell whether this block belongs to the UNet's downsampling path."""
    kind = self.block_type
    return self.BlockType.DOWN == kind
def is_mid_block(self) -> bool:
    """Tell whether this block is the UNet's middle block."""
    kind = self.block_type
    return self.BlockType.MID == kind
def is_upsample_block(self) -> bool:
    """Tell whether this block belongs to the UNet's upsampling path."""
    kind = self.block_type
    return self.BlockType.UP == kind
def has_downsample(self) -> bool:
    """Tell whether this is a down block that carries a downsampling convolution."""
    if self.sampler is None:
        return False
    return self.is_downsample_block()
def has_upsample(self) -> bool:
    """Tell whether this is an up block that carries an upsampling convolution."""
    if self.sampler is None:
        return False
    return self.is_upsample_block()
def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
    """Yield ``(key, name, module, parent_struct, field_name)`` for every key module
    in this block: resnets first, then transformers, then the optional sampler."""
    for child_struct in (*self.resnet_structs, *self.transformer_structs):
        yield from child_struct.named_key_modules()
    if self.sampler is None:
        return
    yield self.sampler_key, self.sampler_name, self.sampler, self, "sampler"
def iter_attention_structs(self) -> tp.Generator[DiffusionAttentionStruct, None, None]:
    """Iterate over every attention struct nested inside this block's transformers."""
    yield from (
        attn_struct
        for tfm_struct in self.transformer_structs
        for attn_struct in tfm_struct.iter_attention_structs()
    )
def iter_transformer_block_structs(self) -> tp.Generator[DiffusionTransformerBlockStruct, None, None]:
    """Iterate over every transformer block struct nested inside this block."""
    yield from (
        block_struct
        for tfm_struct in self.transformer_structs
        for block_struct in tfm_struct.iter_transformer_block_structs()
    )
@staticmethod
def _default_construct(
    module: UNET_BLOCK_CLS,
    /,
    parent: tp.Optional["UNetStruct"] = None,
    fname: str = "",
    rname: str = "",
    rkey: str = "",
    idx: int = 0,
    **kwargs,
) -> "UNetBlockStruct":
    """Build a `UNetBlockStruct` from a diffusers down/mid/up UNet block module.

    Args:
        module (`UNET_BLOCK_CLS`):
            The UNet block module to wrap.
        parent (`UNetStruct` or `None`, *optional*, defaults to `None`):
            The parent struct.
        fname (`str`, *optional*, defaults to `""`):
            The field name of the block in its parent.
        rname (`str`, *optional*, defaults to `""`):
            The relative name of the block.
        rkey (`str`, *optional*, defaults to `""`):
            The relative key of the block.
        idx (`int`, *optional*, defaults to `0`):
            The index of the block among its siblings.

    Returns:
        `UNetBlockStruct`:
            The constructed block struct.

    Raises:
        NotImplementedError: If the module type is not supported.
    """
    resnets, resnets_rname = module.resnets, "resnets"
    if isinstance(module, (DownBlock2D, CrossAttnDownBlock2D)):
        block_type = UNetBlockStruct.BlockType.DOWN
        if isinstance(module, CrossAttnDownBlock2D) and module.attentions is not None:
            transformers, transformers_rname = module.attentions, "attentions"
        else:
            transformers, transformers_rname = [], ""
        if module.downsamplers is None:
            sampler, sampler_rname = None, ""
        else:
            # Only a single downsampler whose conv is a plain Conv2d is supported.
            assert len(module.downsamplers) == 1
            downsampler = module.downsamplers[0]
            assert isinstance(downsampler, Downsample2D)
            sampler, sampler_rname = downsampler.conv, "downsamplers.0.conv"
            assert isinstance(sampler, nn.Conv2d)
    elif isinstance(module, (UNetMidBlock2D, UNetMidBlock2DCrossAttn)):
        block_type = UNetBlockStruct.BlockType.MID
        # `UNetMidBlock2D` only exposes usable attentions when `add_attention` is set.
        if (isinstance(module, UNetMidBlock2DCrossAttn) or module.add_attention) and module.attentions is not None:
            transformers, transformers_rname = module.attentions, "attentions"
        else:
            transformers, transformers_rname = [], ""
        # Mid blocks have no down/upsampler.
        sampler, sampler_rname = None, ""
    elif isinstance(module, (UpBlock2D, CrossAttnUpBlock2D)):
        block_type = UNetBlockStruct.BlockType.UP
        if isinstance(module, CrossAttnUpBlock2D) and module.attentions is not None:
            transformers, transformers_rname = module.attentions, "attentions"
        else:
            transformers, transformers_rname = [], ""
        if module.upsamplers is None:
            sampler, sampler_rname = None, ""
        else:
            # Only a single upsampler whose conv is a plain Conv2d is supported.
            assert len(module.upsamplers) == 1
            upsampler = module.upsamplers[0]
            assert isinstance(upsampler, Upsample2D)
            sampler, sampler_rname = upsampler.conv, "upsamplers.0.conv"
            assert isinstance(sampler, nn.Conv2d)
    else:
        raise NotImplementedError(f"Unsupported module type: {type(module)}")
    return UNetBlockStruct(
        module=module,
        parent=parent,
        fname=fname,
        idx=idx,
        rname=rname,
        rkey=rkey,
        block_type=block_type,
        resnets=resnets,
        transformers=transformers,
        sampler=sampler,
        resnets_rname=resnets_rname,
        transformers_rname=transformers_rname,
        sampler_rname=sampler_rname,
    )
@classmethod
def _get_default_key_map(cls) -> dict[str, set[str]]:
    """Get the default allowed keys.

    Merges the key maps of the resnet and transformer child struct classes,
    prefixing both the relative keys and the mapped keys with this block's
    resnet/transformer relative keys.
    """
    key_map: dict[str, set[str]] = defaultdict(set)
    resnet_cls = cls.resnet_struct_cls
    resnet_key = resnet_rkey = cls.resnet_rkey
    resnet_key_map = resnet_cls._get_default_key_map()
    for rkey, keys in resnet_key_map.items():
        # NOTE(review): `rkey` is rebound to the prefixed name here, so the child's
        # bare relative key is not kept as an entry for resnets — unlike the
        # transformer loop below, which registers both the bare `rkey` and the
        # prefixed `trkey`. Confirm whether this asymmetry is intentional.
        rkey = join_name(resnet_rkey, rkey, sep="_")
        for key in keys:
            key = join_name(resnet_key, key, sep="_")
            key_map[rkey].add(key)
            key_map[resnet_rkey].add(key)
    transformer_cls = cls.transformer_struct_cls
    transformer_key = transformer_rkey = cls.transformer_rkey
    transformer_key_map = transformer_cls._get_default_key_map()
    for rkey, keys in transformer_key_map.items():
        trkey = join_name(transformer_rkey, rkey, sep="_")
        for key in keys:
            key = join_name(transformer_key, key, sep="_")
            key_map[rkey].add(key)
            key_map[trkey].add(key)
    # Drop entries that ended up empty.
    return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class UNetStruct(DiffusionModelStruct):
    """Struct describing a diffusers UNet model (e.g., ``UNet2DConditionModel``).

    Wraps the pre-block embedding modules, the down/mid/up blocks, and the
    post-block output modules, deriving absolute names and keys for each.
    """

    # region relative keys
    input_embed_rkey: tp.ClassVar[str] = "input_embed"
    """hidden_states = input_embed(hidden_states), e.g., conv_in"""
    time_embed_rkey: tp.ClassVar[str] = "time_embed"
    """temb = time_embed(timesteps, hidden_states)"""
    # NOTE: intentionally shares the "time_embed" key with `time_embed_rkey`.
    add_time_embed_rkey: tp.ClassVar[str] = "time_embed"
    """add_temb = add_time_embed(timesteps, encoder_hidden_states)"""
    text_embed_rkey: tp.ClassVar[str] = "text_embed"
    """encoder_hidden_states = text_embed(encoder_hidden_states)"""
    # NOTE: `norm_out` and `proj_out` share the "output_embed" key.
    norm_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = norm_out(hidden_states), e.g., conv_norm_out"""
    proj_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = proj_out(hidden_states), e.g., conv_out"""
    down_block_rkey: tp.ClassVar[str] = "down"
    mid_block_rkey: tp.ClassVar[str] = "mid"
    up_block_rkey: tp.ClassVar[str] = "up"
    down_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    mid_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    up_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    # endregion
    # region child modules
    # region pre-block modules
    input_embed: nn.Conv2d
    time_embed: TimestepEmbedding
    """Time embedding"""
    add_time_embed: (
        TextTimeEmbedding
        | TextImageTimeEmbedding
        | TimestepEmbedding
        | ImageTimeEmbedding
        | ImageHintTimeEmbedding
        | None
    )
    """Additional time embedding"""
    text_embed: nn.Linear | ImageProjection | TextImageProjection | None
    """Text embedding"""
    # region post-block modules
    norm_out: nn.GroupNorm | None
    proj_out: nn.Conv2d
    # endregion
    # endregion
    down_blocks: nn.ModuleList = field(repr=False)
    mid_block: nn.Module = field(repr=False)
    up_blocks: nn.ModuleList = field(repr=False)
    # endregion
    # region relative names
    input_embed_rname: str
    time_embed_rname: str
    add_time_embed_rname: str
    text_embed_rname: str
    norm_out_rname: str
    proj_out_rname: str
    down_blocks_rname: str
    mid_block_rname: str
    up_blocks_rname: str
    # endregion
    # region absolute names (derived in `__post_init__`)
    input_embed_name: str = field(init=False, repr=False)
    time_embed_name: str = field(init=False, repr=False)
    add_time_embed_name: str = field(init=False, repr=False)
    text_embed_name: str = field(init=False, repr=False)
    norm_out_name: str = field(init=False, repr=False)
    proj_out_name: str = field(init=False, repr=False)
    down_blocks_name: str = field(init=False, repr=False)
    mid_block_name: str = field(init=False, repr=False)
    up_blocks_name: str = field(init=False, repr=False)
    down_block_names: list[str] = field(init=False, repr=False)
    up_block_names: list[str] = field(init=False, repr=False)
    # endregion
    # region absolute keys (derived in `__post_init__`)
    input_embed_key: str = field(init=False, repr=False)
    time_embed_key: str = field(init=False, repr=False)
    add_time_embed_key: str = field(init=False, repr=False)
    text_embed_key: str = field(init=False, repr=False)
    norm_out_key: str = field(init=False, repr=False)
    proj_out_key: str = field(init=False, repr=False)
    # endregion
    # region child structs (derived in `__post_init__`)
    down_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
    mid_block_struct: UNetBlockStruct = field(init=False, repr=False)
    up_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
    # endregion

    @property
    def num_down_blocks(self) -> int:
        """Number of downsampling blocks."""
        return len(self.down_blocks)

    @property
    def num_up_blocks(self) -> int:
        """Number of upsampling blocks."""
        return len(self.up_blocks)

    @property
    def num_blocks(self) -> int:
        """Total number of blocks (down + mid + up)."""
        return self.num_down_blocks + 1 + self.num_up_blocks

    @property
    def block_structs(self) -> list[UNetBlockStruct]:
        """All block structs in forward order: down, mid, then up."""
        return [*self.down_block_structs, self.mid_block_struct, *self.up_block_structs]

    def __post_init__(self) -> None:
        """Derive absolute names/keys and build the pre/post module and block structs."""
        super().__post_init__()
        down_block_rnames = [f"{self.down_blocks_rname}.{idx}" for idx in range(len(self.down_blocks))]
        up_block_rnames = [f"{self.up_blocks_rname}.{idx}" for idx in range(len(self.up_blocks))]
        self.down_blocks_name = join_name(self.name, self.down_blocks_rname)
        self.mid_block_name = join_name(self.name, self.mid_block_rname)
        self.up_blocks_name = join_name(self.name, self.up_blocks_rname)
        self.down_block_names = [join_name(self.name, rname) for rname in down_block_rnames]
        self.up_block_names = [join_name(self.name, rname) for rname in up_block_rnames]
        self.pre_module_structs = {}
        for fname in ("time_embed", "add_time_embed", "text_embed", "input_embed"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
            # The absolute name is derived even when the module is absent but a
            # relative name was supplied; otherwise it is left empty.
            if module is not None or rname:
                setattr(self, f"{fname}_name", join_name(self.name, rname))
            else:
                setattr(self, f"{fname}_name", "")
            if module is not None:
                assert rname, f"rname of {fname} must not be empty"
                self.pre_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
                    module=module, parent=self, fname=fname, rname=rname, rkey=rkey
                )
        self.post_module_structs = {}
        for fname in ("norm_out", "proj_out"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
            if module is not None or rname:
                setattr(self, f"{fname}_name", join_name(self.name, rname))
            else:
                setattr(self, f"{fname}_name", "")
            if module is not None:
                self.post_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
                    module=module, parent=self, fname=fname, rname=rname, rkey=rkey
                )
        self.down_block_structs = [
            self.down_block_struct_cls.construct(
                block, parent=self, fname="down_block", rname=rname, rkey=self.down_block_rkey, idx=idx
            )
            for idx, (block, rname) in enumerate(zip(self.down_blocks, down_block_rnames, strict=True))
        ]
        # NOTE(review): the mid block receives the absolute `mid_block_name` as its
        # `rname` (the matching assertion lives in `UNetBlockStruct.__post_init__`).
        self.mid_block_struct = self.mid_block_struct_cls.construct(
            self.mid_block, parent=self, fname="mid_block", rname=self.mid_block_name, rkey=self.mid_block_rkey
        )
        self.up_block_structs = [
            self.up_block_struct_cls.construct(
                block, parent=self, fname="up_block", rname=rname, rkey=self.up_block_rkey, idx=idx
            )
            for idx, (block, rname) in enumerate(zip(self.up_blocks, up_block_rnames, strict=True))
        ]

    def get_prev_module_keys(self) -> tuple[str, ...]:
        """Keys of the modules executed before the UNet blocks (deduplicated via a set)."""
        return tuple({self.input_embed_key, self.time_embed_key, self.add_time_embed_key, self.text_embed_key})

    def get_post_module_keys(self) -> tuple[str, ...]:
        """Keys of the modules executed after the UNet blocks (deduplicated via a set)."""
        return tuple({self.norm_out_key, self.proj_out_key})

    def _get_iter_block_activations_args(
        self, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
        """Get the layers, layer structs, recompute flags, and prev-output-reuse flags
        used when iterating over block activations of the UNet."""
        layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
        num_down_blocks = len(self.down_blocks)
        num_up_blocks = len(self.up_blocks)
        layers.extend(self.down_blocks)
        layer_structs.extend(self.down_block_structs)
        use_prev_layer_outputs.append(False)
        use_prev_layer_outputs.extend([True] * (num_down_blocks - 1))
        recomputes.append(False)
        # region check whether down block's outputs are changed
        _mid_block_additional_residual = input_kwargs.get("mid_block_additional_residual", None)
        _down_block_additional_residuals = input_kwargs.get("down_block_additional_residuals", None)
        _is_adapter = input_kwargs.get("down_intrablock_additional_residuals", None) is not None
        # Down-block residuals without a mid-block residual also indicate an adapter.
        if not _is_adapter and _mid_block_additional_residual is None and _down_block_additional_residuals is not None:
            _is_adapter = True
        for down_block in self.down_blocks:
            if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
                # outputs unchanged
                recomputes.append(False)
            elif _is_adapter:
                # outputs changed
                recomputes.append(True)
            else:
                # outputs unchanged
                recomputes.append(False)
        # endregion
        layers.append(self.mid_block)
        layer_structs.append(self.mid_block_struct)
        use_prev_layer_outputs.append(False)
        # recomputes is already appended in the previous down blocks
        layers.extend(self.up_blocks)
        layer_structs.extend(self.up_block_structs)
        use_prev_layer_outputs.append(False)
        use_prev_layer_outputs.extend([True] * (num_up_blocks - 1))
        # Up block inputs are always recomputed — presumably because they also consume
        # down-block outputs; NOTE(review): confirm against the iteration driver.
        recomputes += [True] * num_up_blocks
        return layers, layer_structs, recomputes, use_prev_layer_outputs

    @staticmethod
    def _default_construct(
        module: tp.Union[UNET_PIPELINE_CLS, UNET_CLS],
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "UNetStruct":
        """Build a `UNetStruct` from a UNet model or a pipeline containing one.

        Raises:
            NotImplementedError: If the module type is not supported.
        """
        # Unwrap pipelines down to the underlying UNet module.
        if isinstance(module, UNET_PIPELINE_CLS):
            module = module.unet
        if isinstance(module, (UNet2DConditionModel, UNet2DModel)):
            input_embed, time_embed = module.conv_in, module.time_embedding
            input_embed_rname, time_embed_rname = "conv_in", "time_embedding"
            text_embed, text_embed_rname = None, ""
            add_time_embed, add_time_embed_rname = None, ""
            # Optional embeddings only exist on some UNet configurations.
            if hasattr(module, "encoder_hid_proj"):
                text_embed, text_embed_rname = module.encoder_hid_proj, "encoder_hid_proj"
            if hasattr(module, "add_embedding"):
                add_time_embed, add_time_embed_rname = module.add_embedding, "add_embedding"
            norm_out, norm_out_rname = module.conv_norm_out, "conv_norm_out"
            proj_out, proj_out_rname = module.conv_out, "conv_out"
            down_blocks, down_blocks_rname = module.down_blocks, "down_blocks"
            mid_block, mid_block_rname = module.mid_block, "mid_block"
            up_blocks, up_blocks_rname = module.up_blocks, "up_blocks"
            return UNetStruct(
                module=module,
                parent=parent,
                fname=fname,
                idx=idx,
                rname=rname,
                rkey=rkey,
                input_embed=input_embed,
                time_embed=time_embed,
                add_time_embed=add_time_embed,
                text_embed=text_embed,
                norm_out=norm_out,
                proj_out=proj_out,
                down_blocks=down_blocks,
                mid_block=mid_block,
                up_blocks=up_blocks,
                input_embed_rname=input_embed_rname,
                time_embed_rname=time_embed_rname,
                add_time_embed_rname=add_time_embed_rname,
                text_embed_rname=text_embed_rname,
                norm_out_rname=norm_out_rname,
                proj_out_rname=proj_out_rname,
                down_blocks_rname=down_blocks_rname,
                mid_block_rname=mid_block_rname,
                up_blocks_rname=up_blocks_rname,
            )
        raise NotImplementedError(f"Unsupported module type: {type(module)}")

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        for idx, (block_key, block_cls) in enumerate(
            (
                (cls.down_block_rkey, cls.down_block_struct_cls),
                (cls.mid_block_rkey, cls.mid_block_struct_cls),
                (cls.up_block_rkey, cls.up_block_struct_cls),
            )
        ):
            block_key_map: dict[str, set[str]] = defaultdict(set)
            # The mid block (idx == 1) has no sampler, so skip its sampler key.
            if idx != 1:
                sampler_key = join_name(block_key, block_cls.sampler_rkey, sep="_")
                sampler_rkey = block_cls.sampler_rkey
                block_key_map[sampler_rkey].add(sampler_key)
            _block_key_map = block_cls._get_default_key_map()
            for rkey, keys in _block_key_map.items():
                for key in keys:
                    key = join_name(block_key, key, sep="_")
                    block_key_map[rkey].add(key)
            for rkey, keys in block_key_map.items():
                key_map[rkey].update(keys)
                if block_key:
                    key_map[block_key].update(keys)
        keys: set[str] = set()
        keys.add(cls.input_embed_rkey)
        keys.add(cls.time_embed_rkey)
        keys.add(cls.add_time_embed_rkey)
        keys.add(cls.text_embed_rkey)
        keys.add(cls.norm_out_rkey)
        keys.add(cls.proj_out_rkey)
        for mapped_keys in key_map.values():
            for key in mapped_keys:
                keys.add(key)
        # Provide a catch-all "embed" entry covering every embedding key, unless
        # "embed" is already used as a concrete key or map entry.
        if "embed" not in keys and "embed" not in key_map:
            key_map["embed"].add(cls.input_embed_rkey)
            key_map["embed"].add(cls.time_embed_rkey)
            key_map["embed"].add(cls.add_time_embed_rkey)
            key_map["embed"].add(cls.text_embed_rkey)
            key_map["embed"].add(cls.norm_out_rkey)
            key_map["embed"].add(cls.proj_out_rkey)
        # Concrete keys map to themselves only.
        for key in keys:
            if key in key_map:
                key_map[key].clear()
                key_map[key].add(key)
        return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class DiTStruct(DiffusionModelStruct, DiffusionTransformerStruct):
    """Struct describing a diffusion transformer (DiT) model, e.g., PixArt, Sana, or SD3."""

    # region relative keys
    input_embed_rkey: tp.ClassVar[str] = "input_embed"
    """hidden_states = input_embed(hidden_states), e.g., conv_in"""
    time_embed_rkey: tp.ClassVar[str] = "time_embed"
    """temb = time_embed(timesteps)"""
    text_embed_rkey: tp.ClassVar[str] = "text_embed"
    """encoder_hidden_states = text_embed(encoder_hidden_states)"""
    # NOTE: `norm_in`/`proj_in` share the "input_embed" key with `input_embed`.
    norm_in_rkey: tp.ClassVar[str] = "input_embed"
    """hidden_states = norm_in(hidden_states)"""
    proj_in_rkey: tp.ClassVar[str] = "input_embed"
    """hidden_states = proj_in(hidden_states)"""
    # NOTE: `norm_out`/`proj_out` share the "output_embed" key.
    norm_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = norm_out(hidden_states)"""
    proj_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = proj_out(hidden_states)"""
    transformer_block_rkey: tp.ClassVar[str] = ""
    # endregion
    # region child modules
    input_embed: PatchEmbed
    time_embed: AdaLayerNormSingle | CombinedTimestepTextProjEmbeddings | TimestepEmbedding
    text_embed: PixArtAlphaTextProjection | nn.Linear
    # The supported DiT models carry no separate input norm/projection modules.
    norm_in: None = field(init=False, repr=False, default=None)
    proj_in: None = field(init=False, repr=False, default=None)
    norm_out: nn.LayerNorm | AdaLayerNormContinuous | None
    proj_out: nn.Linear
    # endregion
    # region relative names
    input_embed_rname: str
    time_embed_rname: str
    text_embed_rname: str
    norm_in_rname: str = field(init=False, repr=False, default="")
    proj_in_rname: str = field(init=False, repr=False, default="")
    norm_out_rname: str
    proj_out_rname: str
    # endregion
    # region absolute names (derived in `__post_init__`)
    input_embed_name: str = field(init=False, repr=False)
    time_embed_name: str = field(init=False, repr=False)
    text_embed_name: str = field(init=False, repr=False)
    # endregion
    # region absolute keys (derived in `__post_init__`)
    input_embed_key: str = field(init=False, repr=False)
    time_embed_key: str = field(init=False, repr=False)
    text_embed_key: str = field(init=False, repr=False)
    norm_out_key: str = field(init=False, repr=False)
    # endregion

    @property
    def num_blocks(self) -> int:
        """Number of transformer blocks."""
        return len(self.transformer_blocks)

    @property
    def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
        """Structs of all transformer blocks."""
        return self.transformer_block_structs

    @property
    def block_names(self) -> list[str]:
        """Absolute names of all transformer blocks."""
        return self.transformer_block_names

    def __post_init__(self) -> None:
        """Derive absolute names/keys and collect the pre/post module structs."""
        super().__post_init__()
        self.pre_module_structs = {}
        for fname in ("input_embed", "time_embed", "text_embed"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
            if module is not None or rname:
                setattr(self, f"{fname}_name", join_name(self.name, rname))
            else:
                setattr(self, f"{fname}_name", "")
            if module is not None:
                # `setdefault` keeps the first struct when two fields resolve to the same name.
                self.pre_module_structs.setdefault(
                    getattr(self, f"{fname}_name"),
                    DiffusionModuleStruct(module=module, parent=self, fname=fname, rname=rname, rkey=rkey),
                )
        self.post_module_structs = {}
        self.norm_out_key = join_name(self.key, self.norm_out_rkey, sep="_")
        for fname in ("norm_out", "proj_out"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            if module is not None:
                # NOTE(review): `{fname}_name` is not assigned here, so it is expected to
                # be populated by a parent struct's `__post_init__` — confirm.
                self.post_module_structs.setdefault(
                    getattr(self, f"{fname}_name"),
                    DiffusionModuleStruct(module=module, parent=self, fname=fname, rname=rname, rkey=rkey),
                )

    def get_prev_module_keys(self) -> tuple[str, ...]:
        """Keys of the modules executed before the transformer blocks (deduplicated)."""
        return tuple({self.input_embed_key, self.time_embed_key, self.text_embed_key})

    def get_post_module_keys(self) -> tuple[str, ...]:
        """Keys of the modules executed after the transformer blocks (deduplicated)."""
        return tuple({self.norm_out_key, self.proj_out_key})

    def _get_iter_block_activations_args(
        self, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
        """
        Get the arguments for iterating over the layers and their activations.

        Args:
            skip_pre_modules (`bool`):
                Whether to skip the pre-modules
            skip_post_modules (`bool`):
                Whether to skip the post-modules

        Returns:
            `tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]`:
                the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
        """
        layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
        layers.extend(self.transformer_blocks)
        layer_structs.extend(self.transformer_block_structs)
        # The first block takes fresh inputs; subsequent blocks chain previous outputs.
        use_prev_layer_outputs.append(False)
        use_prev_layer_outputs.extend([True] * (len(self.transformer_blocks) - 1))
        recomputes.extend([False] * len(self.transformer_blocks))
        return layers, layer_structs, recomputes, use_prev_layer_outputs

    @staticmethod
    def _default_construct(
        module: tp.Union[DIT_PIPELINE_CLS, DIT_CLS],
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiTStruct":
        """Build a `DiTStruct` from a DiT model or a pipeline containing one.

        Flux models are delegated to `FluxStruct.construct`.

        Raises:
            NotImplementedError: If the module type is not supported.
        """
        # Unwrap pipelines down to the underlying transformer module.
        if isinstance(module, DIT_PIPELINE_CLS):
            module = module.transformer
        if isinstance(module, FluxTransformer2DModel):
            return FluxStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
        else:
            if isinstance(module, PixArtTransformer2DModel):
                input_embed, input_embed_rname = module.pos_embed, "pos_embed"
                time_embed, time_embed_rname = module.adaln_single, "adaln_single"
                text_embed, text_embed_rname = module.caption_projection, "caption_projection"
                norm_out, norm_out_rname = module.norm_out, "norm_out"
                proj_out, proj_out_rname = module.proj_out, "proj_out"
                transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
                # ! in fact, `module.adaln_single.emb` is `time_embed`,
                # ! `module.adaln_single.linear` is `transformer_norm`
                # ! but since PixArt shares the `transformer_norm`, we categorize it as `time_embed`
            elif isinstance(module, SanaTransformer2DModel):
                input_embed, input_embed_rname = module.patch_embed, "patch_embed"
                time_embed, time_embed_rname = module.time_embed, "time_embed"
                text_embed, text_embed_rname = module.caption_projection, "caption_projection"
                norm_out, norm_out_rname = module.norm_out, "norm_out"
                proj_out, proj_out_rname = module.proj_out, "proj_out"
                transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
            elif isinstance(module, SD3Transformer2DModel):
                input_embed, input_embed_rname = module.pos_embed, "pos_embed"
                time_embed, time_embed_rname = module.time_text_embed, "time_text_embed"
                text_embed, text_embed_rname = module.context_embedder, "context_embedder"
                norm_out, norm_out_rname = module.norm_out, "norm_out"
                proj_out, proj_out_rname = module.proj_out, "proj_out"
                transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
            else:
                raise NotImplementedError(f"Unsupported module type: {type(module)}")
            return DiTStruct(
                module=module,
                parent=parent,
                fname=fname,
                idx=idx,
                rname=rname,
                rkey=rkey,
                input_embed=input_embed,
                time_embed=time_embed,
                text_embed=text_embed,
                transformer_blocks=transformer_blocks,
                norm_out=norm_out,
                proj_out=proj_out,
                input_embed_rname=input_embed_rname,
                time_embed_rname=time_embed_rname,
                text_embed_rname=text_embed_rname,
                norm_out_rname=norm_out_rname,
                proj_out_rname=proj_out_rname,
                transformer_blocks_rname=transformer_blocks_rname,
            )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        block_cls = cls.transformer_block_struct_cls
        block_key = block_rkey = cls.transformer_block_rkey
        block_key_map = block_cls._get_default_key_map()
        for rkey, keys in block_key_map.items():
            brkey = join_name(block_rkey, rkey, sep="_")
            for key in keys:
                key = join_name(block_key, key, sep="_")
                key_map[rkey].add(key)
                key_map[brkey].add(key)
                if block_rkey:
                    key_map[block_rkey].add(key)
        keys: set[str] = set()
        keys.add(cls.input_embed_rkey)
        keys.add(cls.time_embed_rkey)
        keys.add(cls.text_embed_rkey)
        keys.add(cls.norm_in_rkey)
        keys.add(cls.proj_in_rkey)
        keys.add(cls.norm_out_rkey)
        keys.add(cls.proj_out_rkey)
        for mapped_keys in key_map.values():
            for key in mapped_keys:
                keys.add(key)
        # Provide a catch-all "embed" entry covering every embedding key, unless
        # "embed" is already used as a concrete key or map entry.
        if "embed" not in keys and "embed" not in key_map:
            key_map["embed"].add(cls.input_embed_rkey)
            key_map["embed"].add(cls.time_embed_rkey)
            key_map["embed"].add(cls.text_embed_rkey)
            key_map["embed"].add(cls.norm_in_rkey)
            key_map["embed"].add(cls.proj_in_rkey)
            key_map["embed"].add(cls.norm_out_rkey)
            key_map["embed"].add(cls.proj_out_rkey)
        # Concrete keys map to themselves only.
        for key in keys:
            if key in key_map:
                key_map[key].clear()
                key_map[key].add(key)
        return {k: v for k, v in key_map.items() if v}
@dataclass(kw_only=True)
class FluxStruct(DiTStruct):
    """Struct describing a `FluxTransformer2DModel`, which appends a second list of
    single transformer blocks after the regular transformer blocks."""

    # region relative keys
    single_transformer_block_rkey: tp.ClassVar[str] = ""
    single_transformer_block_struct_cls: tp.ClassVar[type[DiffusionTransformerBlockStruct]] = (
        DiffusionTransformerBlockStruct
    )
    # endregion
    module: FluxTransformer2DModel = field(repr=False, kw_only=False)
    """the module of FluxTransformer2DModel"""
    # region child modules
    input_embed: nn.Linear
    time_embed: CombinedTimestepGuidanceTextProjEmbeddings | CombinedTimestepTextProjEmbeddings
    text_embed: nn.Linear
    single_transformer_blocks: nn.ModuleList = field(repr=False)
    # endregion
    # region relative names
    single_transformer_blocks_rname: str
    # endregion
    # region absolute names (derived in `__post_init__`)
    single_transformer_blocks_name: str = field(init=False, repr=False)
    single_transformer_block_names: list[str] = field(init=False, repr=False)
    # endregion
    # region child structs (derived in `__post_init__`)
    single_transformer_block_structs: list[DiffusionTransformerBlockStruct] = field(init=False)
    # endregion

    @property
    def num_blocks(self) -> int:
        """Total number of blocks (regular + single transformer blocks)."""
        return len(self.transformer_block_structs) + len(self.single_transformer_block_structs)

    @property
    def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
        """All block structs: regular transformer blocks first, then single blocks."""
        return [*self.transformer_block_structs, *self.single_transformer_block_structs]

    @property
    def block_names(self) -> list[str]:
        """Absolute names of all blocks, in the same order as `block_structs`."""
        return [*self.transformer_block_names, *self.single_transformer_block_names]

    def __post_init__(self) -> None:
        """Derive names for the single transformer blocks and build their structs."""
        super().__post_init__()
        single_transformer_block_rnames = [
            f"{self.single_transformer_blocks_rname}.{idx}" for idx in range(len(self.single_transformer_blocks))
        ]
        self.single_transformer_blocks_name = join_name(self.name, self.single_transformer_blocks_rname)
        self.single_transformer_block_names = [join_name(self.name, rname) for rname in single_transformer_block_rnames]
        self.single_transformer_block_structs = [
            self.single_transformer_block_struct_cls.construct(
                block,
                parent=self,
                fname="single_transformer_block",
                rname=rname,
                rkey=self.single_transformer_block_rkey,
                idx=idx,
            )
            for idx, (block, rname) in enumerate(
                zip(self.single_transformer_blocks, single_transformer_block_rnames, strict=True)
            )
        ]

    def _get_iter_block_activations_args(
        self, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
        """Extend the DiT iteration arguments with the single transformer blocks."""
        layers, layer_structs, recomputes, use_prev_layer_outputs = super()._get_iter_block_activations_args()
        layers.extend(self.single_transformer_blocks)
        layer_structs.extend(self.single_transformer_block_structs)
        # The first single block takes fresh inputs; the rest chain previous outputs.
        use_prev_layer_outputs.append(False)
        use_prev_layer_outputs.extend([True] * (len(self.single_transformer_blocks) - 1))
        recomputes.extend([False] * len(self.single_transformer_blocks))
        return layers, layer_structs, recomputes, use_prev_layer_outputs

    @staticmethod
    def _default_construct(
        module: tp.Union[FluxPipeline, FluxControlPipeline, FluxTransformer2DModel],
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "FluxStruct":
        """Build a `FluxStruct` from a Flux model or a pipeline containing one.

        Raises:
            NotImplementedError: If the module type is not supported.
        """
        # Unwrap pipelines down to the underlying transformer module.
        if isinstance(module, (FluxPipeline, FluxControlPipeline)):
            module = module.transformer
        if isinstance(module, FluxTransformer2DModel):
            input_embed, time_embed, text_embed = module.x_embedder, module.time_text_embed, module.context_embedder
            input_embed_rname, time_embed_rname, text_embed_rname = "x_embedder", "time_text_embed", "context_embedder"
            norm_out, norm_out_rname = module.norm_out, "norm_out"
            proj_out, proj_out_rname = module.proj_out, "proj_out"
            transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
            single_transformer_blocks = module.single_transformer_blocks
            single_transformer_blocks_rname = "single_transformer_blocks"
            return FluxStruct(
                module=module,
                parent=parent,
                fname=fname,
                idx=idx,
                rname=rname,
                rkey=rkey,
                input_embed=input_embed,
                time_embed=time_embed,
                text_embed=text_embed,
                transformer_blocks=transformer_blocks,
                single_transformer_blocks=single_transformer_blocks,
                norm_out=norm_out,
                proj_out=proj_out,
                input_embed_rname=input_embed_rname,
                time_embed_rname=time_embed_rname,
                text_embed_rname=text_embed_rname,
                norm_out_rname=norm_out_rname,
                proj_out_rname=proj_out_rname,
                transformer_blocks_rname=transformer_blocks_rname,
                single_transformer_blocks_rname=single_transformer_blocks_rname,
            )
        raise NotImplementedError(f"Unsupported module type: {type(module)}")

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        # Merge the key maps of both the regular and the single transformer blocks.
        for block_rkey, block_cls in (
            (cls.transformer_block_rkey, cls.transformer_block_struct_cls),
            (cls.single_transformer_block_rkey, cls.single_transformer_block_struct_cls),
        ):
            block_key = block_rkey
            block_key_map = block_cls._get_default_key_map()
            for rkey, keys in block_key_map.items():
                brkey = join_name(block_rkey, rkey, sep="_")
                for key in keys:
                    key = join_name(block_key, key, sep="_")
                    key_map[rkey].add(key)
                    key_map[brkey].add(key)
                    if block_rkey:
                        key_map[block_rkey].add(key)
        keys: set[str] = set()
        keys.add(cls.input_embed_rkey)
        keys.add(cls.time_embed_rkey)
        keys.add(cls.text_embed_rkey)
        keys.add(cls.norm_in_rkey)
        keys.add(cls.proj_in_rkey)
        keys.add(cls.norm_out_rkey)
        keys.add(cls.proj_out_rkey)
        for mapped_keys in key_map.values():
            for key in mapped_keys:
                keys.add(key)
        # Provide a catch-all "embed" entry covering every embedding key, unless
        # "embed" is already used as a concrete key or map entry.
        if "embed" not in keys and "embed" not in key_map:
            key_map["embed"].add(cls.input_embed_rkey)
            key_map["embed"].add(cls.time_embed_rkey)
            key_map["embed"].add(cls.text_embed_rkey)
            key_map["embed"].add(cls.norm_in_rkey)
            key_map["embed"].add(cls.proj_in_rkey)
            key_map["embed"].add(cls.norm_out_rkey)
            key_map["embed"].add(cls.proj_out_rkey)
        # Concrete keys map to themselves only.
        for key in keys:
            if key in key_map:
                key_map[key].clear()
                key_map[key].add(key)
        return {k: v for k, v in key_map.items() if v}
# Register the default factories so that `.construct(...)` dispatches on module type.
DiffusionAttentionStruct.register_factory(Attention, DiffusionAttentionStruct._default_construct)
DiffusionFeedForwardStruct.register_factory(
    (FeedForward, FluxSingleTransformerBlock, GLUMBConv), DiffusionFeedForwardStruct._default_construct
)
DiffusionTransformerBlockStruct.register_factory(DIT_BLOCK_CLS, DiffusionTransformerBlockStruct._default_construct)
UNetBlockStruct.register_factory(UNET_BLOCK_CLS, UNetBlockStruct._default_construct)
UNetStruct.register_factory(tp.Union[UNET_PIPELINE_CLS, UNET_CLS], UNetStruct._default_construct)
# Flux models are covered both here and via delegation in `DiTStruct._default_construct`.
FluxStruct.register_factory(
    tp.Union[FluxPipeline, FluxControlPipeline, FluxTransformer2DModel], FluxStruct._default_construct
)
DiTStruct.register_factory(tp.Union[DIT_PIPELINE_CLS, DIT_CLS], DiTStruct._default_construct)
DiffusionTransformerStruct.register_factory(Transformer2DModel, DiffusionTransformerStruct._default_construct)
DiffusionModelStruct.register_factory(tp.Union[PIPELINE_CLS, MODEL_CLS], DiffusionModelStruct._default_construct)
================================================
FILE: deepcompressor/app/diffusion/pipeline/__init__.py
================================================
# -*- coding: utf-8 -*-
from .config import DiffusionPipelineConfig
================================================
FILE: deepcompressor/app/diffusion/pipeline/config.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion pipeline configuration module."""
import gc
import typing as tp
from dataclasses import dataclass, field
import torch
from diffusers.pipelines import (
AutoPipelineForText2Image,
DiffusionPipeline,
FluxControlPipeline,
FluxFillPipeline,
SanaPipeline,
)
from omniconfig import configclass
from torch import nn
from transformers import PreTrainedModel, PreTrainedTokenizer, T5EncoderModel
from deepcompressor.data.utils.dtype import eval_dtype
from deepcompressor.quantizer.processor import Quantizer
from deepcompressor.utils import tools
from deepcompressor.utils.hooks import AccumBranchHook, ProcessHook
from ....nn.patch.linear import ConcatLinear, ShiftedLinear
from ....nn.patch.lowrank import LowRankBranch
from ..nn.patch import (
replace_fused_linear_with_concat_linear,
replace_up_block_conv_with_concat_conv,
shift_input_activations,
)
__all__ = ["DiffusionPipelineConfig"]
@configclass
@dataclass
class LoRAConfig:
    """LoRA configuration.

    Args:
        path (`str`):
            The path of the LoRA branch.
        weight_name (`str`):
            The weight name of the LoRA branch.
        alpha (`float`, *optional*, defaults to `1.0`):
            The alpha value of the LoRA branch.
    """

    path: str
    weight_name: str
    alpha: float = 1.0
@configclass
@dataclass
class DiffusionPipelineConfig:
    """Diffusion pipeline configuration.

    Args:
        name (`str`):
            The name of the pipeline.
        path (`str`, *optional*, defaults to `""`):
            The pretrained pipeline path. If empty, a default hub path is resolved
            from `name` in `_default_build`.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The data type of the pipeline.
        device (`str`, *optional*, defaults to `"cuda"`):
            The device of the pipeline.
        shift_activations (`bool`, *optional*, defaults to `False`):
            Whether to shift activations.
        lora (`LoRAConfig` or `None`, *optional*, defaults to `None`):
            The optional LoRA branch to load into the denoiser.
    """

    # Registry: pipeline name -> factory(name, path, dtype, device, shift_activations).
    _pipeline_factories: tp.ClassVar[
        dict[str, tp.Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline]]
    ] = {}
    # Registry: pipeline name -> text-encoder extractor.
    _text_extractors: tp.ClassVar[
        dict[
            str,
            tp.Callable[
                [DiffusionPipeline, tuple[type[PreTrainedModel], ...]],
                list[tuple[str, PreTrainedModel, PreTrainedTokenizer]],
            ],
        ]
    ] = {}

    name: str
    path: str = ""
    dtype: torch.dtype = field(
        default_factory=lambda s=torch.float32: eval_dtype(s, with_quant_dtype=False, with_none=False)
    )
    device: str = "cuda"
    shift_activations: bool = False
    lora: LoRAConfig | None = None
    # Derived from `name` in `__post_init__` (e.g. "flux.1-dev" -> "flux.1").
    family: str = field(init=False)
    task: str = "text-to-image"

    def __post_init__(self):
        # The family is the first dash-separated component of the pipeline name.
        self.family = self.name.split("-")[0]
        # A few FLUX variants are not text-to-image tasks.
        if self.name == "flux.1-canny-dev":
            self.task = "canny-to-image"
        elif self.name == "flux.1-depth-dev":
            self.task = "depth-to-image"
        elif self.name == "flux.1-fill-dev":
            self.task = "inpainting"

    def build(
        self, *, dtype: str | torch.dtype | None = None, device: str | torch.device | None = None
    ) -> DiffusionPipeline:
        """Build the diffusion pipeline.

        Args:
            dtype (`str` or `torch.dtype`, *optional*):
                The data type of the pipeline.
            device (`str` or `torch.device`, *optional*):
                The device of the pipeline.

        Returns:
            `DiffusionPipeline`:
                The diffusion pipeline.
        """
        if dtype is None:
            dtype = self.dtype
        if device is None:
            device = self.device
        # Use a registered factory when one exists for this pipeline name,
        # otherwise fall back to the generic builder.
        _factory = self._pipeline_factories.get(self.name, self._default_build)
        return _factory(
            name=self.name, path=self.path, dtype=dtype, device=device, shift_activations=self.shift_activations
        )

    def extract_text_encoders(
        self, pipeline: DiffusionPipeline, supported: tuple[type[PreTrainedModel], ...] = (T5EncoderModel,)
    ) -> list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]:
        """Extract the text encoders and tokenizers from the pipeline.

        Args:
            pipeline (`DiffusionPipeline`):
                The diffusion pipeline.
            supported (`tuple[type[PreTrainedModel], ...]`, *optional*, defaults to `(T5EncoderModel,)`):
                The supported text encoder types. If not specified, all text encoders will be extracted.

        Returns:
            `list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]`:
                The list of text encoder name, model, and tokenizer.
        """
        _extractor = self._text_extractors.get(self.name, self._default_extract_text_encoders)
        return _extractor(pipeline, supported)

    @classmethod
    def register_pipeline_factory(
        cls,
        names: str | tuple[str, ...],
        /,
        factory: tp.Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline],
        *,
        overwrite: bool = False,
    ) -> None:
        """Register a pipeline factory.

        Args:
            names (`str` or `tuple[str, ...]`):
                The name of the pipeline.
            factory (`Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline]`):
                The pipeline factory function.
            overwrite (`bool`, *optional*, defaults to `False`):
                Whether to overwrite the existing factory for the pipeline.
        """
        if isinstance(names, str):
            names = [names]
        for name in names:
            if name in cls._pipeline_factories and not overwrite:
                raise ValueError(f"Pipeline factory {name} already exists.")
            cls._pipeline_factories[name] = factory

    @classmethod
    def register_text_extractor(
        cls,
        names: str | tuple[str, ...],
        /,
        extractor: tp.Callable[
            [DiffusionPipeline, tuple[type[PreTrainedModel], ...]],
            list[tuple[str, PreTrainedModel, PreTrainedTokenizer]],
        ],
        *,
        overwrite: bool = False,
    ) -> None:
        """Register a text extractor.

        Args:
            names (`str` or `tuple[str, ...]`):
                The name of the pipeline.
            extractor (`Callable[[DiffusionPipeline], list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]`):
                The text extractor function.
            overwrite (`bool`, *optional*, defaults to `False`):
                Whether to overwrite the existing extractor for the pipeline.
        """
        if isinstance(names, str):
            names = [names]
        for name in names:
            if name in cls._text_extractors and not overwrite:
                raise ValueError(f"Text extractor {name} already exists.")
            cls._text_extractors[name] = extractor

    def load_lora(  # noqa: C901
        self, pipeline: DiffusionPipeline, smooth_cache: dict[str, torch.Tensor] | None = None
    ) -> DiffusionPipeline:
        """Load the configured LoRA branch into the pipeline's denoiser.

        The LoRA deltas are attached as `LowRankBranch` hooks (fused with existing
        branches when present), optionally rescaled by the smooth-quantization
        scales in `smooth_cache`, and finally collected into a `_low_rank_branches`
        submodule so they travel with the model's state dict.

        Args:
            pipeline (`DiffusionPipeline`):
                The diffusion pipeline.
            smooth_cache (`dict[str, torch.Tensor]` or `None`, *optional*):
                The smooth scales keyed by the modules' smooth-cache keys.

        Returns:
            `DiffusionPipeline`:
                The same pipeline, with LoRA branches attached.
        """
        smooth_cache = smooth_cache or {}
        model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
        assert isinstance(model, nn.Module)
        if self.lora is not None:
            logger = tools.logging.getLogger(__name__)
            logger.info(f"Load LoRA branches from {self.lora.path}")
            lora_state_dict, alphas = pipeline.lora_state_dict(
                self.lora.path, return_alphas=True, weight_name=self.lora.weight_name
            )
            tools.logging.Formatter.indent_inc()
            for name, module in model.named_modules():
                if isinstance(module, (nn.Linear, ConcatLinear, ShiftedLinear)):
                    lora_a_key, lora_b_key = f"transformer.{name}.lora_A.weight", f"transformer.{name}.lora_B.weight"
                    if lora_a_key in lora_state_dict:
                        assert lora_b_key in lora_state_dict
                        logger.info(f"+ Load LoRA branch for {name}")
                        tools.logging.Formatter.indent_inc()
                        a = lora_state_dict.pop(lora_a_key)
                        b = lora_state_dict.pop(lora_b_key)
                        assert isinstance(a, torch.Tensor)
                        assert isinstance(b, torch.Tensor)
                        assert a.shape[1] == module.in_features
                        assert b.shape[0] == module.out_features
                        if isinstance(module, ConcatLinear):
                            # A fused linear was split into parts: split the LoRA "A"
                            # along input features; each part shares the same "B".
                            logger.debug(
                                f"- split LoRA branch into {len(module.linears)} parts ({module.in_features_list})"
                            )
                            m_splits = module.linears
                            a_splits = a.split(module.in_features_list, dim=1)
                            b_splits = [b] * len(a_splits)
                        else:
                            m_splits, a_splits, b_splits = [module], [a], [b]
                        for m, a, b in zip(m_splits, a_splits, b_splits, strict=True):
                            assert a.shape[0] == b.shape[1]
                            if isinstance(m, ShiftedLinear):
                                # Work on the inner linear; remember the input shift so the
                                # LoRA's contribution on the shift can be folded into the bias.
                                s, m = m.shift, m.linear
                                logger.debug(f"- shift LoRA input by {s.item() if s.numel() == 1 else s}")
                            else:
                                s = None
                            assert isinstance(m, nn.Linear)
                            device, dtype = m.weight.device, m.weight.dtype
                            # Accumulate in float64 to minimize rounding error.
                            a, b = a.to(device=device, dtype=torch.float64), b.to(device=device, dtype=torch.float64)
                            if s is not None:
                                # s becomes the constant output offset b @ a @ shift (scaled by alpha).
                                if s.numel() == 1:
                                    s = torch.matmul(b, a.sum(dim=1).mul_(s.double())).mul_(self.lora.alpha)
                                else:
                                    s = torch.matmul(b, torch.matmul(a, s.view(1, -1).double())).mul_(self.lora.alpha)
                            if hasattr(m, "in_smooth_cache_key"):
                                # Input was smoothed: scale "A" so the branch sees the smoothed input.
                                logger.debug(f"- smooth LoRA input using {m.in_smooth_cache_key} smooth scale")
                                ss = smooth_cache[m.in_smooth_cache_key].to(device=device, dtype=torch.float64)
                                a = a.mul_(ss.view(1, -1))
                                del ss
                            if hasattr(m, "out_smooth_cache_key"):
                                # Output was smoothed: divide "B" (and the bias offset) accordingly.
                                logger.debug(f"- smooth LoRA output using {m.out_smooth_cache_key} smooth scale")
                                ss = smooth_cache[m.out_smooth_cache_key].to(device=device, dtype=torch.float64)
                                b = b.div_(ss.view(-1, 1))
                                if s is not None:
                                    s = s.div_(ss.view(-1))
                                del ss
                            # Find any existing low-rank branch / quantizer hooks on this linear.
                            branch_hook, quant_hook = None, None
                            for hook in m._forward_pre_hooks.values():
                                if isinstance(hook, AccumBranchHook) and isinstance(hook.branch, LowRankBranch):
                                    branch_hook = hook
                                if isinstance(hook, ProcessHook) and isinstance(hook.processor, Quantizer):
                                    quant_hook = hook
                            if branch_hook is not None:
                                # Fuse the new LoRA with the existing branch by stacking ranks.
                                logger.debug("- fuse with existing LoRA branch")
                                assert isinstance(branch_hook.branch, LowRankBranch)
                                _a = branch_hook.branch.a.weight.data
                                _b = branch_hook.branch.b.weight.data
                                if branch_hook.branch.alpha != self.lora.alpha:
                                    # Different alphas: bake each alpha into its "B" and use alpha=1.
                                    a, b = a.to(dtype=dtype), b.mul_(self.lora.alpha).to(dtype=dtype)
                                    _b = _b.to(dtype=torch.float64).mul_(branch_hook.branch.alpha).to(dtype=dtype)
                                    alpha = 1
                                else:
                                    a, b = a.to(dtype=dtype), b.to(dtype=dtype)
                                    alpha = self.lora.alpha
                                branch_hook.branch = LowRankBranch(
                                    m.in_features,
                                    m.out_features,
                                    rank=a.shape[0] + branch_hook.branch.rank,
                                    alpha=alpha,
                                ).to(device=device, dtype=dtype)
                                branch_hook.branch.a.weight.data[: a.shape[0], :] = a
                                branch_hook.branch.b.weight.data[:, : b.shape[1]] = b
                                branch_hook.branch.a.weight.data[a.shape[0] :, :] = _a
                                branch_hook.branch.b.weight.data[:, b.shape[1] :] = _b
                            else:
                                logger.debug("- create a new LoRA branch")
                                branch = LowRankBranch(
                                    m.in_features, m.out_features, rank=a.shape[0], alpha=self.lora.alpha
                                )
                                branch = branch.to(device=device, dtype=dtype)
                                branch.a.weight.data.copy_(a.to(dtype=dtype))
                                branch.b.weight.data.copy_(b.to(dtype=dtype))
                                # low rank branch hook should be registered before the quantization hook
                                if quant_hook is not None:
                                    logger.debug(f"- remove quantization hook from {name}")
                                    quant_hook.remove(m)
                                logger.debug(f"- register LoRA branch to {name}")
                                branch.as_hook().register(m)
                                if quant_hook is not None:
                                    logger.debug(f"- re-register quantization hook to {name}")
                                    quant_hook.register(m)
                            if s is not None:
                                # Fold the constant offset from the input shift into the bias.
                                assert m.bias is not None
                                m.bias.data.copy_((m.bias.double().sub_(s)).to(dtype))
                        del m_splits, a_splits, b_splits, a, b, s
                        gc.collect()
                        torch.cuda.empty_cache()
                        tools.logging.Formatter.indent_dec()
            tools.logging.Formatter.indent_dec()
            if len(lora_state_dict) > 0:
                logger.warning(f"Unused LoRA weights: {lora_state_dict.keys()}")
        # Collect every low-rank branch into a registered submodule so it is part
        # of the model's state dict.
        branches = nn.ModuleList()
        for _, module in model.named_modules():
            for hook in module._forward_hooks.values():
                if isinstance(hook, AccumBranchHook) and isinstance(hook.branch, LowRankBranch):
                    branches.append(hook.branch)
        model.register_module("_low_rank_branches", branches)
        # Fix: the method is annotated to return the pipeline but previously
        # returned None implicitly.
        return pipeline

    @staticmethod
    def _default_build(
        name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
    ) -> DiffusionPipeline:
        """Build a pipeline for a known model name, patch its denoiser, and move it to `device`."""
        if not path:
            # Resolve the default Hugging Face hub path from the model name.
            if name == "sdxl":
                path = "stabilityai/stable-diffusion-xl-base-1.0"
            elif name == "sdxl-turbo":
                path = "stabilityai/sdxl-turbo"
            elif name == "pixart-sigma":
                path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
            elif name == "flux.1-dev":
                path = "black-forest-labs/FLUX.1-dev"
            elif name == "flux.1-canny-dev":
                path = "black-forest-labs/FLUX.1-Canny-dev"
            elif name == "flux.1-depth-dev":
                path = "black-forest-labs/FLUX.1-Depth-dev"
            elif name == "flux.1-fill-dev":
                path = "black-forest-labs/FLUX.1-Fill-dev"
            elif name == "flux.1-schnell":
                path = "black-forest-labs/FLUX.1-schnell"
            else:
                raise ValueError(f"Path for {name} is not specified.")
        if name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
            pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name == "flux.1-fill-dev":
            pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
        elif name.startswith("sana-"):
            if dtype == torch.bfloat16:
                # Sana ships a dedicated bf16 variant; cast the sub-models that the
                # variant leaves in their original dtype.
                pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
                pipeline.vae.to(dtype)
                pipeline.text_encoder.to(dtype)
            else:
                pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
        else:
            pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)
        pipeline = pipeline.to(device)
        model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
        # Patch the denoiser for quantization-friendly structures.
        replace_fused_linear_with_concat_linear(model)
        replace_up_block_conv_with_concat_conv(model)
        if shift_activations:
            shift_input_activations(model)
        return pipeline

    @staticmethod
    def _default_extract_text_encoders(
        pipeline: DiffusionPipeline, supported: tuple[type[PreTrainedModel], ...]
    ) -> list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]:
        """Extract the text encoders and tokenizers from the pipeline.

        Args:
            pipeline (`DiffusionPipeline`):
                The diffusion pipeline.
            supported (`tuple[type[PreTrainedModel], ...]`, *optional*, defaults to `(T5EncoderModel,)`):
                The supported text encoder types. If not specified, all text encoders will be extracted.

        Returns:
            `list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]`:
                The list of text encoder name, model, and tokenizer.
        """
        results: list[tuple[str, PreTrainedModel, PreTrainedTokenizer]] = []
        # Fix: iterate over the pipeline instance's attributes. The previous code
        # used `vars.__dict__.keys()`, which accesses the builtin `vars` function
        # itself and raises AttributeError at runtime.
        for key in vars(pipeline).keys():
            if key.startswith("text_encoder"):
                # Pair "text_encoder<suffix>" with its matching "tokenizer<suffix>".
                suffix = key[len("text_encoder") :]
                encoder, tokenizer = getattr(pipeline, f"text_encoder{suffix}"), getattr(pipeline, f"tokenizer{suffix}")
                if not supported or isinstance(encoder, supported):
                    results.append((key, encoder, tokenizer))
        return results
================================================
FILE: deepcompressor/app/diffusion/ptq.py
================================================
import gc
import json
import os
import pprint
import traceback
import torch
from diffusers import DiffusionPipeline
from deepcompressor.app.llm.nn.patch import patch_attention, patch_gemma_rms_norm
from deepcompressor.app.llm.ptq import ptq as llm_ptq
from deepcompressor.utils import tools
from .config import DiffusionPtqCacheConfig, DiffusionPtqRunConfig, DiffusionQuantCacheConfig, DiffusionQuantConfig
from .nn.struct import DiffusionModelStruct
from .quant import (
load_diffusion_weights_state_dict,
quantize_diffusion_activations,
quantize_diffusion_weights,
rotate_diffusion,
smooth_diffusion,
)
__all__ = ["ptq"]
def ptq(  # noqa: C901
    model: DiffusionModelStruct,
    config: DiffusionQuantConfig,
    cache: DiffusionPtqCacheConfig | None = None,
    load_dirpath: str = "",
    save_dirpath: str = "",
    copy_on_save: bool = False,
    save_model: bool = False,
) -> DiffusionModelStruct:
    """Post-training quantization of a diffusion model.

    Args:
        model (`DiffusionModelStruct`):
            The diffusion model.
        config (`DiffusionQuantConfig`):
            The diffusion model post-training quantization configuration.
        cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
            The diffusion model quantization cache path configuration.
        load_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to load the quantization checkpoint.
        save_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to save the quantization checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the cache to the save directory.
        save_model (`bool`, *optional*, defaults to `False`):
            Whether to save the quantized model checkpoint.

    Returns:
        `DiffusionModelStruct`:
            The quantized diffusion model.
    """
    logger = tools.logging.getLogger(__name__)
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    # Which quantization stages are enabled by the configuration.
    quant_wgts = config.enabled_wgts
    quant_ipts = config.enabled_ipts
    quant_opts = config.enabled_opts
    quant_acts = quant_ipts or quant_opts
    quant = quant_wgts or quant_acts
    # region resolve load/save checkpoint paths
    load_model_path, load_path, save_path = "", None, None
    if load_dirpath:
        load_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(load_dirpath, "smooth.pt"),
            branch=os.path.join(load_dirpath, "branch.pt"),
            wgts=os.path.join(load_dirpath, "wgts.pt"),
            acts=os.path.join(load_dirpath, "acts.pt"),
        )
        load_model_path = os.path.join(load_dirpath, "model.pt")
        if os.path.exists(load_model_path):
            if config.enabled_wgts and config.wgts.enabled_low_rank:
                # A quantized model checkpoint is only usable together with its
                # low-rank branch checkpoint.
                if os.path.exists(load_path.branch):
                    load_model = True
                else:
                    logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
                    load_model = False
            else:
                load_model = True
            if load_model:
                logger.info(f"* Loading model from {load_model_path}")
                save_dirpath = ""  # do not save the model if loading
        else:
            logger.warning(f"Model checkpoint {load_model_path} does not exist")
            load_model = False
    else:
        load_model = False
    if save_dirpath:
        os.makedirs(save_dirpath, exist_ok=True)
        save_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(save_dirpath, "smooth.pt"),
            branch=os.path.join(save_dirpath, "branch.pt"),
            wgts=os.path.join(save_dirpath, "wgts.pt"),
            acts=os.path.join(save_dirpath, "acts.pt"),
        )
    else:
        # Without a save directory there is nowhere to write the model checkpoint.
        save_model = False
    # endregion
    if quant and config.enabled_rotation:
        logger.info("* Rotating model for quantization")
        tools.logging.Formatter.indent_inc()
        rotate_diffusion(model, config=config)
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # region smooth quantization
    if quant and config.enabled_smooth:
        logger.info("* Smoothing model for quantization")
        tools.logging.Formatter.indent_inc()
        # Prefer the explicit load checkpoint, then the shared cache.
        load_from = ""
        if load_path and os.path.exists(load_path.smooth):
            load_from = load_path.smooth
        elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
            load_from = cache.path.smooth
        if load_from:
            logger.info(f"- Loading smooth scales from {load_from}")
            smooth_cache = torch.load(load_from)
            smooth_diffusion(model, config, smooth_cache=smooth_cache)
        else:
            logger.info("- Generating smooth scales")
            smooth_cache = smooth_diffusion(model, config)
            if cache and cache.path.smooth:
                logger.info(f"- Saving smooth scales to {cache.path.smooth}")
                os.makedirs(cache.dirpath.smooth, exist_ok=True)
                torch.save(smooth_cache, cache.path.smooth)
                load_from = cache.path.smooth
        if save_path:
            # Symlink instead of copying when the scales already live on disk.
            if not copy_on_save and load_from:
                logger.info(f"- Linking smooth scales to {save_path.smooth}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
            else:
                logger.info(f"- Saving smooth scales to {save_path.smooth}")
                torch.save(smooth_cache, save_path.smooth)
        del smooth_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region collect original state dict
    # Snapshot pre-quantization weights only when activation calibration will
    # need them and no cached activation settings can be loaded instead.
    if config.needs_acts_quantizer_cache:
        if load_path and os.path.exists(load_path.acts):
            orig_state_dict = None
        elif cache and cache.path.acts and os.path.exists(cache.path.acts):
            orig_state_dict = None
        else:
            orig_state_dict: dict[str, torch.Tensor] = {
                name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
            }
    else:
        orig_state_dict = None
    # endregion
    if load_model:
        logger.info(f"* Loading model checkpoint from {load_model_path}")
        load_diffusion_weights_state_dict(
            model,
            config,
            state_dict=torch.load(load_model_path),
            branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
        )
        gc.collect()
        torch.cuda.empty_cache()
    elif quant_wgts:
        logger.info("* Quantizing weights")
        tools.logging.Formatter.indent_inc()
        # Load cached weight-quantizer settings if available.
        quantizer_state_dict, quantizer_load_from = None, ""
        if load_path and os.path.exists(load_path.wgts):
            quantizer_load_from = load_path.wgts
        elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
            quantizer_load_from = cache.path.wgts
        if quantizer_load_from:
            logger.info(f"- Loading weight settings from {quantizer_load_from}")
            quantizer_state_dict = torch.load(quantizer_load_from)
        # Load cached low-rank branch settings if available.
        branch_state_dict, branch_load_from = None, ""
        if load_path and os.path.exists(load_path.branch):
            branch_load_from = load_path.branch
        elif cache and cache.path.branch and os.path.exists(cache.path.branch):
            branch_load_from = cache.path.branch
        if branch_load_from:
            logger.info(f"- Loading branch settings from {branch_load_from}")
            branch_state_dict = torch.load(branch_load_from)
        if not quantizer_load_from:
            logger.info("- Generating weight settings")
        if not branch_load_from:
            logger.info("- Generating branch settings")
        quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
            model,
            config,
            quantizer_state_dict=quantizer_state_dict,
            branch_state_dict=branch_state_dict,
            return_with_scale_state_dict=bool(save_dirpath),
        )
        # Persist freshly generated settings into the shared cache.
        if not quantizer_load_from and cache and cache.dirpath.wgts:
            logger.info(f"- Saving weight settings to {cache.path.wgts}")
            os.makedirs(cache.dirpath.wgts, exist_ok=True)
            torch.save(quantizer_state_dict, cache.path.wgts)
            quantizer_load_from = cache.path.wgts
        if not branch_load_from and cache and cache.dirpath.branch:
            logger.info(f"- Saving branch settings to {cache.path.branch}")
            os.makedirs(cache.dirpath.branch, exist_ok=True)
            torch.save(branch_state_dict, cache.path.branch)
            branch_load_from = cache.path.branch
        if save_path:
            if not copy_on_save and quantizer_load_from:
                logger.info(f"- Linking weight settings to {save_path.wgts}")
                os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
            else:
                logger.info(f"- Saving weight settings to {save_path.wgts}")
                torch.save(quantizer_state_dict, save_path.wgts)
            if not copy_on_save and branch_load_from:
                logger.info(f"- Linking branch settings to {save_path.branch}")
                os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
            else:
                logger.info(f"- Saving branch settings to {save_path.branch}")
                torch.save(branch_state_dict, save_path.branch)
        if save_model:
            logger.info(f"- Saving model to {save_dirpath}")
            torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
            torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
        del quantizer_state_dict, branch_state_dict, scale_state_dict
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    if quant_acts:
        logger.info(" * Quantizing activations")
        tools.logging.Formatter.indent_inc()
        if config.needs_acts_quantizer_cache:
            load_from = ""
            if load_path and os.path.exists(load_path.acts):
                load_from = load_path.acts
            elif cache and cache.path.acts and os.path.exists(cache.path.acts):
                load_from = cache.path.acts
            if load_from:
                logger.info(f"- Loading activation settings from {load_from}")
                quantizer_state_dict = torch.load(load_from)
                quantize_diffusion_activations(
                    model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
                )
            else:
                logger.info("- Generating activation settings")
                quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
                if cache and cache.dirpath.acts and quantizer_state_dict is not None:
                    logger.info(f"- Saving activation settings to {cache.path.acts}")
                    os.makedirs(cache.dirpath.acts, exist_ok=True)
                    torch.save(quantizer_state_dict, cache.path.acts)
                    load_from = cache.path.acts
            if save_dirpath:
                if not copy_on_save and load_from:
                    logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
                    os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
                else:
                    logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
                    torch.save(quantizer_state_dict, save_path.acts)
            del quantizer_state_dict
        else:
            # Dynamic/no-cache activation quantization needs no persisted settings.
            logger.info("- No need to generate/load activation quantizer settings")
            quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
        tools.logging.Formatter.indent_dec()
    del orig_state_dict
    gc.collect()
    torch.cuda.empty_cache()
    return model
def main(config: DiffusionPtqRunConfig, logging_level: int = tools.logging.DEBUG) -> DiffusionPipeline:
    """Post-training quantization of a diffusion model.

    Args:
        config (`DiffusionPtqRunConfig`):
            The diffusion model post-training quantization configuration.
        logging_level (`int`, *optional*, defaults to `logging.DEBUG`):
            The logging level.

    Returns:
        `DiffusionPipeline`:
            The diffusion pipeline with quantized model.
    """
    config.output.lock()
    config.dump(path=config.output.get_running_job_path("config.yaml"))
    tools.logging.setup(path=config.output.get_running_job_path("run.log"), level=logging_level)
    logger = tools.logging.getLogger(__name__)
    logger.info("=== Configurations ===")
    tools.logging.info(config.formatted_str(), logger=logger)
    logger.info("=== Dumped Configurations ===")
    tools.logging.info(pprint.pformat(config.dump(), indent=2, width=120), logger=logger)
    logger.info("=== Output Directory ===")
    logger.info(config.output.job_dirpath)
    logger.info("=== Start Evaluating ===")
    logger.info("* Building diffusion model pipeline")
    tools.logging.Formatter.indent_inc()
    pipeline = config.pipeline.build()
    if "nf4" not in config.pipeline.name and "gguf" not in config.pipeline.name:
        # NOTE(review): when the name contains "nf4"/"gguf", `model` stays unbound
        # and the `ptq(...)` call below would raise NameError — confirm those
        # pre-quantized pipelines are meant to skip PTQ entirely.
        model = DiffusionModelStruct.construct(pipeline)
    tools.logging.Formatter.indent_dec()
    # Resolve where to save quantization artifacts (and optionally the model).
    save_dirpath = os.path.join(config.output.running_job_dirpath, "cache")
    if config.save_model:
        if config.save_model.lower() in ("false", "none", "null", "nil"):
            save_model = False
        elif config.save_model.lower() in ("true", "default"):
            save_dirpath, save_model = os.path.join(config.output.running_job_dirpath, "model"), True
        else:
            # Any other string is treated as an explicit output directory.
            save_dirpath, save_model = config.save_model, True
    else:
        save_model = False
    model = ptq(
        model,
        config.quant,
        cache=config.cache,
        load_dirpath=config.load_from,
        save_dirpath=save_dirpath,
        copy_on_save=config.copy_on_save,
        save_model=save_model,
    )
    if config.pipeline.lora is not None:
        # Locate smooth scales so the LoRA branch can be rescaled consistently.
        load_from = ""
        if config.quant.enabled_smooth:
            if config.load_from and os.path.exists(os.path.join(config.load_from, "smooth.pt")):
                load_from = os.path.join(config.load_from, "smooth.pt")
            elif config.cache.path and os.path.exists(config.cache.path.smooth):
                load_from = config.cache.path.smooth
            elif os.path.exists(os.path.join(save_dirpath, "smooth.pt")):
                load_from = os.path.join(save_dirpath, "smooth.pt")
            logger.info(f"* Loading smooth scales from {load_from}")
        config.pipeline.load_lora(pipeline, smooth_cache=torch.load(load_from) if load_from else None)
    if config.text is not None and config.text.is_enabled():
        for encoder_name, encoder, tokenizer in config.pipeline.extract_text_encoders(pipeline):
            logger.info(f"* Post-training quantizing the text encoder {encoder_name}")
            patch_attention(encoder)
            patch_gemma_rms_norm(encoder)
            # NOTE(review): `save_dirpath` is re-joined with "encoder" on every
            # iteration, nesting paths when multiple encoders exist — confirm.
            save_dirpath = os.path.join(save_dirpath, "encoder")
            setattr(
                pipeline,
                encoder_name,
                llm_ptq(
                    encoder,
                    tokenizer,
                    config.text,
                    cache=config.text_cache,
                    load_dirpath=os.path.join(config.load_from, "encoder") if config.load_from else "",
                    save_dirpath=save_dirpath,
                    copy_on_save=config.copy_on_save,
                    save_model=save_model,
                ),
            )
    config.eval.gen_root = config.eval.gen_root.format(
        output=config.output.running_dirpath, job=config.output.running_job_dirname
    )
    if config.skip_eval:
        if not config.skip_gen:
            logger.info("* Generating image")
            tools.logging.Formatter.indent_inc()
            config.eval.generate(pipeline, task=config.pipeline.task)
            tools.logging.Formatter.indent_dec()
    else:
        logger.info(f"* Evaluating model {'(skipping generation)' if config.skip_gen else ''}")
        tools.logging.Formatter.indent_inc()
        results = config.eval.evaluate(pipeline, skip_gen=config.skip_gen, task=config.pipeline.task)
        tools.logging.Formatter.indent_dec()
        if results is not None:
            logger.info(f"* Saving results to {config.output.job_dirpath}")
            with open(config.output.get_running_job_path("results.json"), "w") as f:
                json.dump(results, f, indent=2, sort_keys=True)
    config.output.unlock()
    # Fix: the function is annotated to return the pipeline but previously
    # returned None implicitly.
    return pipeline
if __name__ == "__main__":
    # Parse CLI arguments into a DiffusionPtqRunConfig; warn about unused
    # configurations/arguments and fail hard on unknown ones.
    config, _, unused_cfgs, unused_args, unknown_args = DiffusionPtqRunConfig.get_parser().parse_known_args()
    assert isinstance(config, DiffusionPtqRunConfig)
    if len(unused_cfgs) > 0:
        tools.logging.warning(f"Unused configurations: {unused_cfgs}")
    if unused_args is not None:
        tools.logging.warning(f"Unused arguments: {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"
    try:
        main(config, logging_level=tools.logging.DEBUG)
    except Exception as e:
        # Log the full traceback, flush logging, release the output lock with an
        # error flag so stale locks do not block future runs, then re-raise.
        tools.logging.Formatter.indent_reset()
        tools.logging.error("=== Error ===")
        tools.logging.error(traceback.format_exc())
        tools.logging.shutdown()
        traceback.print_exc()
        config.output.unlock(error=True)
        raise e
================================================
FILE: deepcompressor/app/diffusion/quant/__init__.py
================================================
# -*- coding: utf-8 -*-
from .activation import quantize_diffusion_activations
from .config import DiffusionQuantCacheConfig, DiffusionQuantConfig
from .quantizer import DiffusionActivationQuantizer, DiffusionWeightQuantizer
from .rotate import rotate_diffusion
from .smooth import smooth_diffusion
from .weight import load_diffusion_weights_state_dict, quantize_diffusion_weights
================================================
FILE: deepcompressor/app/diffusion/quant/activation.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model activation quantization calibration module."""
import gc
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.data.common import TensorType
from deepcompressor.utils import tools
from ..nn.struct import (
DiffusionAttentionStruct,
DiffusionBlockStruct,
DiffusionModelStruct,
DiffusionModuleStruct,
DiffusionTransformerBlockStruct,
)
from .config import DiffusionQuantConfig
from .quantizer import DiffusionActivationQuantizer
from .utils import get_needs_inputs_fn, get_needs_outputs_fn
__all__ = ["quantize_diffusion_activations"]
@torch.inference_mode()
def quantize_diffusion_block_activations(  # noqa: C901
    layer: DiffusionBlockStruct | DiffusionModuleStruct,
    config: DiffusionQuantConfig,
    quantizer_state_dict: dict[str, dict[str, torch.Tensor | float | None]],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
    orig_state_dict: dict[str, torch.Tensor] | None = None,
) -> dict[str, DiffusionActivationQuantizer]:
    """Quantize the activations of a diffusion model block.

    Args:
        layer (`DiffusionBlockStruct` or `DiffusionModuleStruct`):
            The diffusion model block.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        quantizer_state_dict (`dict[str, dict[str, torch.Tensor | float | None]]`):
            The activation quantizers state dict cache.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*):
            The layer cache.
        layer_kwargs (`dict[str, Any]`, *optional*):
            The layer keyword arguments.
        orig_state_dict (`dict[str, torch.Tensor]`, *optional*):
            The original state dictionary.

    Returns:
        `dict[str, DiffusionActivationQuantizer]`:
            The activation quantizers.
    """
    logger = tools.logging.getLogger(f"{__name__}.ActivationQuant")
    logger.debug("- Quantizing layer %s", layer.name)
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    orig_state_dict = orig_state_dict or {}
    # region collect calibration arguments for each (group of) module(s)
    args_caches: list[
        tuple[
            str,  # key
            TensorType,
            list[nn.Linear],  # modules
            list[str],  # module names
            nn.Module,  # eval module
            str,  # eval name
            dict[str, tp.Any],  # eval kwargs
            list[tuple[nn.Parameter, torch.Tensor]],  # original wgts
        ]
    ] = []
    In, Out = TensorType.Inputs, TensorType.Outputs  # noqa: F841
    used_modules: set[nn.Module] = set()
    for module_key, module_name, module, parent, field_name in layer.named_key_modules():
        modules, orig_struct_wgts = None, {}
        # k/v (and add_q/add_v) projections share the same input as the q
        # (resp. add_k) projection, so they are grouped there and skipped here.
        if field_name in ("k_proj", "v_proj", "add_q_proj", "add_v_proj"):
            continue
        if field_name in ("q_proj", "add_k_proj", "up_proj"):
            grandparent = parent.parent
            assert isinstance(grandparent, DiffusionTransformerBlockStruct)
            if grandparent.parallel and parent.idx == 0:
                # Parallel transformer block: attention and FFN consume the same
                # input, so qkv (+ FFN up_proj) are calibrated jointly, with the
                # whole block as the evaluation module.
                if orig_state_dict:
                    orig_struct_wgts = {
                        proj_module: (proj_module.weight, orig_state_dict[f"{proj_name}.weight"])
                        for _, proj_name, proj_module, _, _ in grandparent.named_key_modules()
                    }
                if field_name == "q_proj":
                    assert isinstance(parent, DiffusionAttentionStruct)
                    assert module_name == parent.q_proj_name
                    modules, module_names = parent.qkv_proj, parent.qkv_proj_names
                    if grandparent.ffn_struct is not None:
                        modules.append(grandparent.ffn_struct.up_proj)
                        module_names.append(grandparent.ffn_struct.up_proj_name)
                elif field_name == "add_k_proj":
                    assert isinstance(parent, DiffusionAttentionStruct)
                    assert module_name == parent.add_k_proj_name
                    modules, module_names = parent.add_qkv_proj, parent.add_qkv_proj_names
                    if grandparent.add_ffn_struct is not None:
                        modules.append(grandparent.add_ffn_struct.up_proj)
                        module_names.append(grandparent.add_ffn_struct.up_proj_name)
                else:
                    assert field_name == "up_proj"
                    # Already grouped with a qkv group above; nothing more to do.
                    if module in used_modules:
                        continue
                    assert module_name == grandparent.add_ffn_struct.up_proj_name
                    assert grandparent.attn_structs[0].is_self_attn()
                eval_module, eval_name, eval_kwargs = grandparent.module, grandparent.name, layer_kwargs
            elif isinstance(parent, DiffusionAttentionStruct):
                # Non-parallel attention: calibrate qkv (or add_qkv) jointly,
                # evaluating against the attention module itself.
                eval_module, eval_name = parent.module, parent.name
                eval_kwargs = parent.filter_kwargs(layer_kwargs) if layer_kwargs else {}
                if orig_state_dict:
                    orig_struct_wgts = {
                        proj_module: (proj_module.weight, orig_state_dict[f"{proj_name}.weight"])
                        for _, proj_name, proj_module, _, _ in parent.named_key_modules()
                    }
                if field_name == "q_proj":
                    assert module_name == parent.q_proj_name
                    modules, module_names = parent.qkv_proj, parent.qkv_proj_names
                else:
                    assert field_name == "add_k_proj"
                    assert module_name == parent.add_k_proj_name
                    modules, module_names = parent.add_qkv_proj, parent.add_qkv_proj_names
        if modules is None:
            # Stand-alone module: calibrated individually on its own inputs.
            assert module not in used_modules
            used_modules.add(module)
            orig_wgts = [(module.weight, orig_state_dict[f"{module_name}.weight"])] if orig_state_dict else None
            args_caches.append((module_key, In, [module], [module_name], module, module_name, None, orig_wgts))
        else:
            # Grouped modules: collect their original weights first, then any
            # remaining weights of the enclosing struct (for evaluation).
            orig_wgts = []
            for proj_module in modules:
                assert proj_module not in used_modules
                used_modules.add(proj_module)
                if orig_state_dict:
                    orig_wgts.append(orig_struct_wgts.pop(proj_module))
            orig_wgts.extend(orig_struct_wgts.values())
            orig_wgts = None if not orig_wgts else orig_wgts
            args_caches.append((module_key, In, modules, module_names, eval_module, eval_name, eval_kwargs, orig_wgts))
    # endregion
    quantizers: dict[str, DiffusionActivationQuantizer] = {}
    tools.logging.Formatter.indent_inc()
    for module_key, tensor_type, modules, module_names, eval_module, eval_name, eval_kwargs, orig_wgts in args_caches:
        # Linear activations are quantized along the last dim, convs along dim 1.
        if isinstance(modules[0], nn.Linear):
            channels_dim = -1
            assert all(isinstance(m, nn.Linear) for m in modules)
        elif isinstance(modules[0], nn.Conv2d):
            channels_dim = 1
            assert all(isinstance(m, nn.Conv2d) for m in modules)
        else:
            raise ValueError(f"Unknown module type: {type(modules[0])}")
        if tensor_type == TensorType.Inputs:
            cache_keys = [f"{name}.input" for name in module_names]
            quantizer_config = config.unsigned_ipts if getattr(modules[0], "unsigned", False) else config.ipts
            activations = layer_cache.get(module_names[0], IOTensorsCache()).inputs
        else:
            cache_keys = [f"{name}.output" for name in module_names]
            quantizer_config = config.opts
            activations = layer_cache.get(module_names[0], IOTensorsCache()).outputs
        quantizer = DiffusionActivationQuantizer(
            quantizer_config,
            channels_dim=channels_dim,
            develop_dtype=config.develop_dtype,
            key=module_key,
            tensor_type=tensor_type,
        )
        if quantizer.is_enabled():
            if cache_keys[0] not in quantizer_state_dict:
                # Calibrate and cache the dynamic range under the group's first key.
                logger.debug("- Calibrating %s", ", ".join(cache_keys))
                quantizer.calibrate_dynamic_range(
                    modules=modules,
                    activations=activations,
                    eval_module=eval_module,
                    eval_inputs=layer_cache[eval_name].inputs if layer_cache else None,
                    eval_kwargs=eval_kwargs,
                    orig_weights=orig_wgts,
                )
                quantizer_state_dict[cache_keys[0]] = quantizer.state_dict()
                gc.collect()
                torch.cuda.empty_cache()
            else:
                quantizer.load_state_dict(quantizer_state_dict[cache_keys[0]], device=modules[0].weight.device)
        # Every module in the group shares the same quantizer instance.
        for cache_key in cache_keys:
            quantizers[cache_key] = quantizer
        del quantizer
    tools.logging.Formatter.indent_dec()
    return quantizers
@torch.inference_mode()
def quantize_diffusion_activations(
    model: nn.Module | DiffusionModelStruct,
    config: DiffusionQuantConfig,
    quantizer_state_dict: dict[str, dict[str, torch.Tensor | float | None]] | None = None,
    orig_state_dict: dict[str, torch.Tensor] | None = None,
) -> dict[str, dict[str, torch.Tensor | float | None]]:
    """Quantize the activations of a diffusion model.

    Builds (or restores from cache) one activation quantizer per module
    input/output, then registers forward hooks on the model so activations are
    quantized on the fly at inference time.

    Args:
        model (`nn.Module` or `DiffusionModelStruct`):
            The diffusion model.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        quantizer_state_dict (`dict[str, dict[str, torch.Tensor | float | None]]`, *optional*, defaults to `None`):
            The activation quantizers state dict cache.
        orig_state_dict (`dict[str, torch.Tensor]`, *optional*, defaults to `None`):
            The original state dictionary.

    Returns:
        `dict[str, dict[str, torch.Tensor | float | None]]`:
            The activation quantizers state dict cache (the same dict that was
            passed in, if any, updated in place by the block-level calibration).
    """
    logger = tools.logging.getLogger(f"{__name__}.ActivationQuant")
    # Normalize to the structured wrapper used by the calibration utilities.
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    quantizer_state_dict = quantizer_state_dict or {}
    quantizers: dict[str, DiffusionActivationQuantizer] = {}
    # Pre/post modules are skipped entirely when every one of their keys is
    # already in the input-quantization skip list.
    skip_pre_modules = all(key in config.ipts.skips for key in model.get_prev_module_keys())
    skip_post_modules = all(key in config.ipts.skips for key in model.get_post_module_keys())
    if not quantizer_state_dict and config.needs_acts_quantizer_cache:
        # No cached quantizer states yet and calibration data is required:
        # stream layer-by-layer activations from the calibration loader and
        # calibrate each block's quantizers.
        with tools.logging.redirect_tqdm():
            for _, (layer, layer_cache, layer_kwargs) in tqdm(
                config.calib.build_loader().iter_layer_activations(
                    model,
                    needs_inputs_fn=get_needs_inputs_fn(model, config=config),
                    needs_outputs_fn=get_needs_outputs_fn(model, config=config),
                    skip_pre_modules=skip_pre_modules,
                    skip_post_modules=skip_post_modules,
                ),
                desc="quantizing activations",
                leave=False,
                # NOTE(review): pre modules appear to contribute 3 iterations
                # each when not skipped — confirm against iter_layer_activations.
                total=model.num_blocks + int(not skip_post_modules) + int(not skip_pre_modules) * 3,
                dynamic_ncols=True,
            ):
                block_quantizers = quantize_diffusion_block_activations(
                    layer=layer,
                    config=config,
                    quantizer_state_dict=quantizer_state_dict,
                    layer_cache=layer_cache,
                    layer_kwargs=layer_kwargs,
                    orig_state_dict=orig_state_dict,
                )
                quantizers.update(block_quantizers)
    else:
        # Quantizer states are cached (or no calibration data is needed), so
        # build the quantizers without collecting any activations.
        for _, layer in model.get_named_layers(
            skip_pre_modules=skip_pre_modules, skip_post_modules=skip_post_modules
        ).items():
            block_quantizers = quantize_diffusion_block_activations(
                layer=layer,
                config=config,
                quantizer_state_dict=quantizer_state_dict,
                orig_state_dict=orig_state_dict,
            )
            quantizers.update(block_quantizers)
    # Attach forward hooks that quantize each module's inputs and/or outputs.
    for _, module_name, module, _, _ in model.named_key_modules():
        ipts_quantizer = quantizers.get(f"{module_name}.input", None)
        opts_quantizer = quantizers.get(f"{module_name}.output", None)
        needs_quant_ipts = ipts_quantizer is not None and ipts_quantizer.is_enabled()
        needs_quant_opts = opts_quantizer is not None and opts_quantizer.is_enabled()
        if needs_quant_ipts or needs_quant_opts:
            logger.debug(
                "- Quantizing %s (%s)",
                module_name,
                ("inputs" if needs_quant_ipts else "")
                + (" and " if needs_quant_ipts and needs_quant_opts else "")
                + ("outputs" if needs_quant_opts else ""),
            )
            if needs_quant_ipts:
                ipts_quantizer.as_hook(is_output=False).register(module)
            if needs_quant_opts:
                opts_quantizer.as_hook(is_output=True).register(module)
    return quantizer_state_dict
================================================
FILE: deepcompressor/app/diffusion/quant/config.py
================================================
# -*- coding: utf-8 -*-
"""Quantization config."""
import os
from dataclasses import dataclass, field
import torch
from omniconfig import configclass
from deepcompressor.calib.config import (
QuantRotationConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
SmoothTransfomerConfig,
)
from deepcompressor.data.utils.dtype import eval_dtype
from deepcompressor.quantizer.config import QuantLowRankConfig
from deepcompressor.utils.common import num2str
from ..cache.config import DiffusionQuantCacheConfig
from ..dataset.calib import DiffusionCalibCacheLoaderConfig
from .quantizer.config import DiffusionModuleQuantizerConfig
__all__ = ["DiffusionQuantConfig"]
@configclass
@dataclass(kw_only=True)
class DiffusionQuantConfig(DiffusionModuleQuantizerConfig):
    """Diffusion model quantization configuration.

    Args:
        wgts (`DiffusionWeightQuantizerConfig`):
            The weight quantization configuration.
        ipts (`DiffusionActivationQuantizerConfig`):
            The input activation quantization configuration.
        opts (`DiffusionActivationQuantizerConfig`):
            The output activation quantization configuration.
        calib (`DiffusionCalibCacheLoaderConfig`):
            The calibration dataset configuration.
        rotation (`QuantRotationConfig` or `None`, *optional*, defaults to `None`):
            The rotation quantization configuration.
        smooth (`SmoothTransfomerConfig` or `None`, *optional*, defaults to `None`):
            The smooth quantization configuration.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The development data type used during calibration computations.
    """

    calib: DiffusionCalibCacheLoaderConfig
    rotation: QuantRotationConfig | None = None
    smooth: SmoothTransfomerConfig | None = None
    develop_dtype: torch.dtype = field(default_factory=lambda s=torch.float32: eval_dtype(s, with_quant_dtype=False))

    def __post_init__(self) -> None:  # noqa: C901
        super().__post_init__()
        # Collapse effectively-disabled sub-configs to None so the
        # `enabled_*` properties below stay consistent with reality.
        if self.rotation is not None and not self.rotation.transforms:
            self.rotation = None
        if self.smooth is not None:
            if not self.smooth.enabled_proj and not self.smooth.enabled_attn:
                self.smooth = None
        if self.enabled_smooth and self.smooth.enabled_proj and self.smooth.proj.allow_low_rank:
            # Low-rank smoothing is only meaningful when weight low-rank
            # calibration is enabled; when kept, it forces layer granularity.
            if self.enabled_wgts:
                self.smooth.proj.allow_low_rank = self.wgts.enabled_low_rank
                if self.smooth.proj.allow_low_rank:
                    self.smooth.proj.granularity = SearchBasedCalibGranularity.Layer
            else:
                self.smooth.proj.allow_low_rank = False
        if self.enabled_ipts:
            # Group granularity is not supported for activation range search;
            # fall back to channel-group granularity.
            if self.ipts.enabled_calib_range and self.ipts.calib_range.granularity == SearchBasedCalibGranularity.Group:
                self.ipts.calib_range.granularity = SearchBasedCalibGranularity.ChannelGroup
            if self.ipts.static:
                assert self.ipts.smallest_group_shape[0] == -1, "static quantization requires batch group size to be -1"
        if self.enabled_opts:
            if self.opts.enabled_calib_range and self.opts.calib_range.granularity == SearchBasedCalibGranularity.Group:
                self.opts.calib_range.granularity = SearchBasedCalibGranularity.ChannelGroup
            if self.opts.static:
                assert self.opts.smallest_group_shape[0] == -1, "static quantization requires batch group size to be -1"
        # Expand shorthand module keys and deduplicate the skip lists.
        self.organize()
        # Derived config used for modules flagged as producing unsigned inputs.
        self.unsigned_ipts = self.ipts.for_unsigned()

    @property
    def enabled_rotation(self) -> bool:
        """Whether to enable rotation."""
        return self.rotation is not None and bool(self.rotation.transforms)

    @property
    def enabled_smooth(self) -> bool:
        """Whether to enable smooth quantization."""
        return self.smooth is not None

    @property
    def enabled_smooth_proj(self) -> bool:
        """Whether to enable smooth quantization for projections."""
        return self.enabled_smooth and self.smooth.enabled_proj

    @property
    def enabled_smooth_attn(self) -> bool:
        """Whether to enable smooth quantization for attentions."""
        return self.enabled_smooth and self.smooth.enabled_attn

    @property
    def needs_acts_quantizer_cache(self) -> bool:
        """Whether to cache the activations quantizer settings."""
        if self.enabled_ipts and self.ipts.needs_calib_data:
            return True
        if self.enabled_opts and self.opts.needs_calib_data:
            return True
        return False

    def generate_calib_dirname(self) -> str:
        """Generate the calibration directory name fragment for this config."""
        name = ""
        if self.enabled_rotation:
            name += "-rotate"
            if self.rotation.random:
                name += ".rnd"
        if self.enabled_smooth:
            name += "-smooth"
            if self.enabled_smooth_proj:
                name += ".proj"
            if self.enabled_smooth_attn:
                name += ".attn"
        calib_name = super().generate_calib_dirname()
        if calib_name:
            name += f"-{calib_name}"
        # Strip the leading "-" separator if anything was appended.
        return name[1:] if name else name

    def generate_cache_dirpath(
        self, *, root: str, shift: bool, default_dtype: torch.dtype = torch.float16
    ) -> DiffusionQuantCacheConfig:  # noqa: C901
        """Generate the cache paths for the module quantization configuration."""
        # `quant_names` accumulates as we go, so each later cache directory is
        # keyed by every earlier setting that can affect its contents.
        quant_names = self.generate_dirnames(default_dtype=default_dtype)
        if shift:
            quant_names.append("shift")
        if self.enabled_wgts and self.wgts.enabled_low_rank:
            quant_names.extend(QuantLowRankConfig.generate_dirnames(self.wgts.low_rank, prefix="lowrank"))
        if self.enabled_rotation:
            quant_names.extend(self.rotation.generate_dirnames(prefix="rotate"))
        smooth_dirpath = ""
        if self.enabled_smooth:
            quant_names.extend(self.smooth.generate_dirnames(prefix="smooth"))
            smooth_dirpath = os.path.join("smooth", *quant_names)
        branch_dirpath = ""
        if self.enabled_wgts and self.wgts.enabled_low_rank:
            quant_names.extend(self.wgts.low_rank.generate_dirnames(prefix="lowrank"))
            branch_dirpath = os.path.join("branch", *quant_names)
        wgts_dirpath = ""
        if self.enabled_wgts and self.wgts.needs_calib_data:
            quant_names.extend(self.wgts.calib_range.generate_dirnames(prefix="w.range"))
            wgts_dirpath = os.path.join("wgts", *quant_names)
        if self.enabled_wgts and self.wgts.enabled_gptq:
            quant_names.extend(self.wgts.kernel_gptq.generate_dirnames(prefix="w.kernel"))
        acts_dirpath = ""
        if self.needs_acts_quantizer_cache:
            if self.enabled_ipts and self.ipts.needs_calib_data:
                quant_names.extend(self.ipts.calib_range.generate_dirnames(prefix="x.range"))
            if self.enabled_opts and self.opts.needs_calib_data:
                quant_names.extend(self.opts.calib_range.generate_dirnames(prefix="y.range"))
            acts_dirpath = os.path.join("acts", *quant_names)
        cache_dirpath = DiffusionQuantCacheConfig(
            smooth=smooth_dirpath, branch=branch_dirpath, wgts=wgts_dirpath, acts=acts_dirpath
        ).simplify(type(self)._key_map)
        cache_dirpath = cache_dirpath.add_parent_dirs(*self.calib.generate_dirnames())
        cache_dirpath = cache_dirpath.add_parent_dirs(root, "diffusion", "cache", "quant")
        return cache_dirpath

    def generate_default_dirname(self) -> str:  # noqa: C901
        """Generate the default output directory name for evaluating a quantized diffusion model."""
        key_map = type(self)._key_map

        def simplify_skips(skips):
            # Render the keys through the path simplifier ("skip.[a+b+...]"),
            # then slice off the "skip.[" prefix and "]" suffix and split back
            # into a set of simplified keys.
            return set(
                DiffusionQuantCacheConfig.simplify_path("skip.[{}]".format("+".join(skips)), key_map=key_map)[
                    6:-1
                ].split("+")
            )

        skip_name, y_skips, w_skips, x_skips = "", set(), set(), set()
        if self.enabled_opts and self.opts.skips:
            y_skips = simplify_skips(self.opts.skips)
        if self.enabled_ipts and self.ipts.skips:
            x_skips = simplify_skips(self.ipts.skips)
        if self.enabled_wgts and self.wgts.skips:
            w_skips = simplify_skips(self.wgts.skips)
        skips_map = {}
        if y_skips or x_skips or w_skips:
            skip_name = "-skip"
            skip_name_list: list[tuple[str, set]] = []
            if y_skips:
                skip_name_list.append(("y", y_skips))
            if x_skips:
                skip_name_list.append(("x", x_skips))
            if w_skips:
                skip_name_list.append(("w", w_skips))
            # sort the keys by the number of elements in the set
            skip_name_list = sorted(skip_name_list, key=lambda x: (len(x[1]), x[0]), reverse=False)
            skips_map = {k: v for k, v in skip_name_list}  # noqa: C416
            # Compress the printed names: when a larger skip set contains a
            # smaller one, replace the shared keys with "[y]"/"[x]"/"[w]".
            skip_name_map: dict[str, set] = {}
            skip_0, skip_0_names = skip_name_list[0]
            skip_name_map[skip_0] = skip_0_names
            if len(skip_name_list) > 1:
                skip_1, skip_1_names = skip_name_list[1]
                if skip_1_names.issuperset(skip_0_names):
                    skip_1_names = skip_1_names - skip_0_names
                    skip_1_names.add(f"[{skip_0}]")
                skip_name_map[skip_1] = skip_1_names
            if len(skip_name_list) > 2:
                skip_2, skip_2_names = skip_name_list[2]
                if skip_2_names.issuperset(skip_name_list[1][1]):  # skip_1_names may be modified
                    skip_2_names = skip_2_names - skip_name_list[1][1]
                    skip_2_names.add(f"[{skip_1}]")
                if skip_2_names.issuperset(skip_0_names):
                    skip_2_names = skip_2_names - skip_0_names
                    skip_2_names.add(f"[{skip_0}]")
                skip_name_map[skip_2] = skip_2_names
            if "y" in skip_name_map:
                skip_name += f".y.[{'+'.join(sorted(skip_name_map['y']))}]"
            if "x" in skip_name_map:
                skip_name += f".x.[{'+'.join(sorted(skip_name_map['x']))}]"
            if "w" in skip_name_map:
                skip_name += f".w.[{'+'.join(sorted(skip_name_map['w']))}]"
            del skip_name_list, skip_name_map
        extra_name = ""
        if self.enabled_extra_wgts:
            extra_name = "-extra.[{}]".format("+".join(sorted(simplify_skips(self.extra_wgts.includes))))
        lowrank_name = ""
        if self.enabled_wgts and self.wgts.enabled_low_rank:
            lowrank_name = f"-low.r{num2str(self.wgts.low_rank.rank)}"
            if self.wgts.low_rank.num_iters > 1:
                lowrank_name += f".i{num2str(self.wgts.low_rank.num_iters)}"
            if self.wgts.low_rank.early_stop:
                lowrank_name += ".e"
            if self.wgts.low_rank.exclusive:
                lowrank_name += ".s"
            if self.wgts.low_rank.compensate:
                lowrank_name += ".c"
            if self.wgts.low_rank.objective != SearchBasedCalibObjective.OutputsError:
                lowrank_name += f".{self.wgts.low_rank.objective.name}"
            if self.wgts.low_rank.skips:
                lowrank_skips = simplify_skips(self.wgts.low_rank.skips)
                if "w" in skips_map and lowrank_skips.issuperset(skips_map["w"]):
                    lowrank_skips = lowrank_skips - skips_map["w"]
                    lowrank_skips.add("[w]")
                lowrank_name += ".skip.[{}]".format("+".join(sorted(lowrank_skips)))
        rotation_name = ""
        if self.enabled_rotation:
            rotation_name = "-rot"
            if self.rotation.random:
                rotation_name += ".rnd"
            rotation_name += ".[{}]".format("+".join(sorted(simplify_skips(self.rotation.transforms))))
        smooth_name = ""
        if self.enabled_smooth:
            smooth_name = "-smth"
            if self.smooth.enabled_proj:
                smooth_name += ".proj"
                if self.smooth.proj.granularity != SearchBasedCalibGranularity.Layer:
                    smooth_name += f".{self.smooth.proj.granularity.name}"
                if self.smooth.proj.strategy != SearchBasedCalibStrategy.Manual:
                    smooth_name += f".{self.smooth.proj.strategy.name}"
                    # For searched strategies, alpha/beta only show up when
                    # they mark a searched (non-positive) hyperparameter.
                    if self.smooth.proj.alpha <= 0:
                        smooth_name += f".a{num2str(self.smooth.proj.alpha)}"
                    if self.smooth.proj.beta <= 0:
                        smooth_name += f".b{num2str(self.smooth.proj.beta)}"
                else:
                    smooth_name += f".a{num2str(self.smooth.proj.alpha)}"
                    smooth_name += f".b{num2str(self.smooth.proj.beta)}"
                xspan_eq_wspan = True
                for xspan, wspan in self.smooth.proj.spans:
                    if xspan != wspan:
                        xspan_eq_wspan = False
                        break
                if xspan_eq_wspan:
                    smooth_name += ".[{}]".format("+".join(xspan.name for xspan, _ in self.smooth.proj.spans))
                else:
                    smooth_name += ".[{}]".format(
                        "+".join(f"x.{xspan.name}.w.{wspan.name}" for xspan, wspan in self.smooth.proj.spans)
                    )
                if self.smooth.proj.allow_low_rank:
                    smooth_name += ".lr"
                if not self.smooth.proj.allow_b_quant or not self.smooth.proj.allow_a_quant:
                    smooth_name += ".no.["
                    if not self.smooth.proj.allow_a_quant:
                        smooth_name += "a+"
                    if not self.smooth.proj.allow_b_quant:
                        smooth_name += "b+"
                    # Replace the trailing "+" with the closing bracket.
                    smooth_name = smooth_name[:-1] + "]"
                if self.smooth.proj.skips:
                    smooth_skips = simplify_skips(self.smooth.proj.skips)
                    if "w" in skips_map and smooth_skips.issuperset(skips_map["w"]):
                        smooth_skips = smooth_skips - skips_map["w"]
                        smooth_skips.add("[w]")
                    smooth_name += ".skip.[{}]".format("+".join(sorted(smooth_skips)))
            if self.smooth.enabled_attn:
                smooth_name += ".yx"
                if self.smooth.attn.granularity != SearchBasedCalibGranularity.Layer:
                    smooth_name += f".{self.smooth.attn.granularity.name}"
                if self.smooth.attn.strategy != SearchBasedCalibStrategy.Manual:
                    smooth_name += f".{self.smooth.attn.strategy.name}"
                    if self.smooth.attn.alpha <= 0:
                        smooth_name += f".a{num2str(self.smooth.attn.alpha)}"
                    if self.smooth.attn.beta <= 0:
                        smooth_name += f".b{num2str(self.smooth.attn.beta)}"
                else:
                    smooth_name += f".a{num2str(self.smooth.attn.alpha)}"
                    smooth_name += f".b{num2str(self.smooth.attn.beta)}"
                xspan_eq_yspan = True
                for xspan, yspan in self.smooth.attn.spans:
                    if xspan != yspan:
                        xspan_eq_yspan = False
                        break
                if xspan_eq_yspan:
                    smooth_name += ".[{}]".format("+".join(xspan.name for xspan, _ in self.smooth.attn.spans))
                else:
                    smooth_name += ".[{}]".format(
                        "+".join(f"x.{xspan.name}.y.{yspan.name}" for xspan, yspan in self.smooth.attn.spans)
                    )
        gptq_name = ""
        if self.enabled_wgts and self.wgts.kernel_gptq is not None:
            gptq_name = "-gptq"
            if self.wgts.kernel_gptq.skips:
                gptq_skips = simplify_skips(self.wgts.kernel_gptq.skips)
                if "w" in skips_map and gptq_skips.issuperset(skips_map["w"]):
                    gptq_skips = gptq_skips - skips_map["w"]
                    gptq_skips.add("[w]")
                gptq_name += ".skip.[{}]".format("+".join(sorted(gptq_skips)))
        wrange_name = ""
        if (
            self.enabled_wgts
            and self.wgts.enabled_calib_range
            and (self.wgts.calib_range.needs_search or self.wgts.calib_range.ratio != 1)
        ):
            wrange_name = "-w.range"
            if self.wgts.calib_range.needs_search:
                if self.wgts.calib_range.granularity != SearchBasedCalibGranularity.Group:
                    wrange_name += f".{self.wgts.calib_range.granularity.name}"
                if self.wgts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
                    wrange_name += f".{self.wgts.calib_range.objective.name}"
                if self.wgts.calib_range.degree != 2:
                    wrange_name += f".d{num2str(self.wgts.calib_range.degree)}"
                wrange_name += f".[{num2str(self.wgts.calib_range.max_shrink)}"
                wrange_name += f".{num2str(self.wgts.calib_range.max_expand)}"
                wrange_name += f".g{self.wgts.calib_range.num_grids}]"
            else:
                wrange_name += f".r{num2str(self.wgts.calib_range.ratio)}"
            if self.wgts.calib_range.skips:
                wrange_skips = simplify_skips(self.wgts.calib_range.skips)
                if "w" in skips_map and wrange_skips.issuperset(skips_map["w"]):
                    wrange_skips = wrange_skips - skips_map["w"]
                    wrange_skips.add("[w]")
                wrange_name += ".skip.[{}]".format("+".join(sorted(wrange_skips)))
        xrange_name = ""
        if (
            self.enabled_ipts
            and self.ipts.enabled_calib_range
            and (self.ipts.calib_range.needs_search or self.ipts.calib_range.ratio != 1)
        ):
            xrange_name = "-x.range"
            if self.ipts.calib_range.needs_search:
                if self.ipts.calib_range.granularity != SearchBasedCalibGranularity.Group:
                    xrange_name += f".{self.ipts.calib_range.granularity.name}"
                if self.ipts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
                    xrange_name += f".{self.ipts.calib_range.objective.name}"
                if self.ipts.calib_range.degree != 2:
                    xrange_name += f".d{num2str(self.ipts.calib_range.degree)}"
                xrange_name += f".[{num2str(self.ipts.calib_range.max_shrink)}"
                xrange_name += f".{num2str(self.ipts.calib_range.max_expand)}"
                xrange_name += f".g{self.ipts.calib_range.num_grids}]"
            else:
                xrange_name += f".r{num2str(self.ipts.calib_range.ratio)}"
            if self.ipts.calib_range.skips:
                xrange_skips = simplify_skips(self.ipts.calib_range.skips)
                if "x" in skips_map and xrange_skips.issuperset(skips_map["x"]):
                    xrange_skips = xrange_skips - skips_map["x"]
                    xrange_skips.add("[x]")
                xrange_name += ".skip.[{}]".format("+".join(sorted(xrange_skips)))
        yrange_name = ""
        if (
            self.enabled_opts
            and self.opts.enabled_calib_range
            and (self.opts.calib_range.needs_search or self.opts.calib_range.ratio != 1)
        ):
            yrange_name = "-y.range"
            if self.opts.calib_range.needs_search:
                if self.opts.calib_range.granularity != SearchBasedCalibGranularity.Group:
                    yrange_name += f".{self.opts.calib_range.granularity.name}"
                if self.opts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
                    yrange_name += f".{self.opts.calib_range.objective.name}"
                if self.opts.calib_range.degree != 2:
                    yrange_name += f".d{num2str(self.opts.calib_range.degree)}"
                yrange_name += f".[{num2str(self.opts.calib_range.max_shrink)}"
                yrange_name += f".{num2str(self.opts.calib_range.max_expand)}"
                yrange_name += f".g{self.opts.calib_range.num_grids}]"
            else:
                yrange_name += f".r{num2str(self.opts.calib_range.ratio)}"
            if self.opts.calib_range.skips:
                yrange_skips = simplify_skips(self.opts.calib_range.skips)
                if "y" in skips_map and yrange_skips.issuperset(skips_map["y"]):
                    yrange_skips = yrange_skips - skips_map["y"]
                    yrange_skips.add("[y]")
                yrange_name += ".skip.[{}]".format("+".join(sorted(yrange_skips)))
        name = (
            skip_name
            + extra_name
            + lowrank_name
            + rotation_name
            + smooth_name
            + gptq_name
            + wrange_name
            + xrange_name
            + yrange_name
        )
        # Strip the leading "-"; an all-default config is named "default".
        name = name[1:] if name else "default"
        name += f"-{self.calib.generate_dirnames()[0]}"
        return name

    @classmethod
    def set_key_map(cls, key_map: dict[str, set[str]]) -> None:
        """Set the key map for the diffusion model quantization configuration.

        Args:
            key_map (dict[str, set[str]]): Mapping from shorthand module keys
                to the concrete module keys they expand to.
        """
        cls._key_map = key_map

    def organize(self) -> None:  # noqa: C901
        """Normalize the skip/include lists of every sub-configuration in place.

        Each shorthand key is expanded into the concrete module keys recorded
        in the class-level key map, deduplicated, and sorted; sub-config skip
        lists are reduced by the keys already skipped by their parent
        quantizer so no key is listed twice.
        """
        key_map = type(self)._key_map
        wgts_skip_set, ipts_skip_set, opts_skip_set = set(), set(), set()
        if self.wgts is not None:
            wgts_skips = []
            for skip in self.wgts.skips:
                wgts_skips.extend(list(key_map[skip]))
            wgts_skip_set = set(wgts_skips)
            self.wgts.skips = sorted(wgts_skip_set)
            if self.wgts.kernel_gptq is not None:
                wgts_kernel_gptq_skips = []
                for skip in self.wgts.kernel_gptq.skips:
                    wgts_kernel_gptq_skips.extend(list(key_map[skip]))
                self.wgts.kernel_gptq.skips = sorted(set(wgts_kernel_gptq_skips) - wgts_skip_set)
            if self.wgts.low_rank is not None:
                wgts_low_rank_skips = []
                for skip in self.wgts.low_rank.skips:
                    wgts_low_rank_skips.extend(list(key_map[skip]))
                self.wgts.low_rank.skips = sorted(set(wgts_low_rank_skips) - wgts_skip_set)
            if self.wgts.calib_range is not None:
                wgts_calib_range_skips = []
                for skip in self.wgts.calib_range.skips:
                    wgts_calib_range_skips.extend(list(key_map[skip]))
                self.wgts.calib_range.skips = sorted(set(wgts_calib_range_skips) - wgts_skip_set)
        if self.extra_wgts is not None:
            extra_includes = []
            for include in self.extra_wgts.includes:
                extra_includes.extend(list(key_map[include]))
            # Modules already skipped for regular weight quantization cannot
            # be extra-quantized.
            extra_includes_set = set(extra_includes) - wgts_skip_set
            self.extra_wgts.includes = sorted(extra_includes_set)
            if not self.extra_wgts.is_enabled():
                self.extra_wgts = None
        if self.ipts is not None:
            ipts_skips = []
            for skip in self.ipts.skips:
                ipts_skips.extend(list(key_map[skip]))
            ipts_skip_set = set(ipts_skips)
            self.ipts.skips = sorted(ipts_skip_set)
            if self.ipts.calib_range is not None:
                ipts_calib_range_skips = []
                for skip in self.ipts.calib_range.skips:
                    ipts_calib_range_skips.extend(list(key_map[skip]))
                self.ipts.calib_range.skips = sorted(set(ipts_calib_range_skips) - ipts_skip_set)
        if self.opts is not None:
            opts_skips = []
            for skip in self.opts.skips:
                opts_skips.extend(list(key_map[skip]))
            opts_skip_set = set(opts_skips)
            self.opts.skips = sorted(opts_skip_set)
            if self.opts.calib_range is not None:
                opts_calib_range_skips = []
                for skip in self.opts.calib_range.skips:
                    opts_calib_range_skips.extend(list(key_map[skip]))
                self.opts.calib_range.skips = sorted(set(opts_calib_range_skips) - opts_skip_set)
        if self.smooth is not None and self.smooth.proj is not None:
            smooth_proj_skips = []
            for skip in self.smooth.proj.skips:
                smooth_proj_skips.extend(list(key_map[skip]))
            # Projection smoothing touches weights AND inputs, so only keys
            # skipped by BOTH may be dropped from its own skip list.
            self.smooth.proj.skips = sorted(set(smooth_proj_skips) - (wgts_skip_set & ipts_skip_set))
        if self.rotation is not None:
            rotation_transforms = []
            for transform in self.rotation.transforms:
                rotation_transforms.extend(list(key_map[transform]))
            self.rotation.transforms = sorted(set(rotation_transforms))
================================================
FILE: deepcompressor/app/diffusion/quant/quantizer/__init__.py
================================================
# -*- coding: utf-8 -*-
from .config import DiffusionModuleQuantizerConfig
from .quantizer import DiffusionActivationQuantizer, DiffusionWeightQuantizer
================================================
FILE: deepcompressor/app/diffusion/quant/quantizer/config.py
================================================
# -*- coding: utf-8 -*-
"""Quantizatizer config."""
from dataclasses import dataclass, field
import torch
from omniconfig import configclass
from deepcompressor.calib.config import SkipBasedDynamicRangeCalibConfig, SkipBasedQuantLowRankCalibConfig
from deepcompressor.data.dtype import QuantDataType
from deepcompressor.quantizer.config import QuantizerConfig
from deepcompressor.quantizer.kernel import QuantGptqConfig
from deepcompressor.utils.config import EnableConfig, IncludeBasedConfig, SkipBasedConfig
__all__ = [
"DiffusionQuantizerConfig",
"DiffusionWeightQuantizerConfig",
"DiffusionActivationQuantizerConfig",
"DiffusionModuleQuantizerConfig",
]
@configclass
@dataclass
class DiffusionGPTQConfig(SkipBasedConfig, QuantGptqConfig):
    """Configuration for GPTQ quantization.

    All fields are inherited from `SkipBasedConfig` and `QuantGptqConfig`.

    Args:
        damp_percentage (`float`, *optional*, defaults to `0.01`):
            The percentage of damping.
        block_size (`int`, *optional*, defaults to `128`):
            The block size of the GPTQ quantization.
        num_inv_tries (`int`, *optional*, defaults to `200`):
            The number of tries for the inverse.
        hessian_block_size (`int`, *optional*, defaults to `-1`):
            The block size when calculating the Hessian.
        skips (`list[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
    """

    pass
@configclass
@dataclass
class DiffusionQuantizerConfig(QuantizerConfig):
    """Base quantizer configuration shared by diffusion weight and activation quantizers.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data types for per-group quantization.
        static (`bool`, *optional*, defaults to `False`):
            Whether to use static quantization.
        kernel_gptq (`DiffusionGPTQConfig` or `None`, *optional*, defaults to `None`):
            The GPTQ quantization configuration.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch calibration configuration.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantizer dynamic range calibration configuration.
    """

    static: bool = False
    kernel_gptq: DiffusionGPTQConfig | None = None
    low_rank: SkipBasedQuantLowRankCalibConfig | None = None
    calib_range: SkipBasedDynamicRangeCalibConfig | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.quant_dtype is None:
            # Quantization is disabled: every calibration option is moot.
            self.static = False
            self.kernel_gptq = None
            self.low_rank = None
            self.calib_range = None
        if self.kernel_gptq is not None and not self.kernel_gptq.is_enabled():
            self.kernel_gptq = None
        if self.low_rank is not None and not self.low_rank.is_enabled():
            self.low_rank = None
        if self.static and self.calib_range is None:
            # Static quantization always needs a dynamic-range calibration step.
            self.calib_range = SkipBasedDynamicRangeCalibConfig()

    @property
    def enabled_gptq(self) -> bool:
        """Whether GPTQ kernel calibration is enabled."""
        if self.kernel_gptq is None:
            return False
        return self.kernel_gptq.is_enabled()

    @property
    def enabled_low_rank(self) -> bool:
        """Whether low-rank branch (SVD) calibration is enabled."""
        if self.low_rank is None:
            return False
        return self.low_rank.is_enabled()

    @property
    def enabled_calib_range(self) -> bool:
        """Whether dynamic range calibration is enabled."""
        return self.calib_range is not None

    def generate_calib_dirname(self) -> str:
        """Generate the name for quantization calibration.

        Returns:
            str: The dot-joined name of the enabled calibration options
                (empty string if none are enabled).
        """
        parts: list[str] = []
        if self.static:
            parts.append("static")
        if self.enabled_gptq:
            parts.append("gptq")
        if self.enabled_low_rank:
            parts.append("lowrank")
        if self.enabled_calib_range and (self.calib_range.needs_search or self.calib_range.ratio != 1):
            parts.append("range")
        return ".".join(parts)
@configclass
@dataclass
class SkipBasedDiffusionQuantizerConfig(SkipBasedConfig, DiffusionQuantizerConfig):
    """Diffusion model quantizer configuration with a module skip list.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data types for per-group quantization.
        skips (`list[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        static (`bool`, *optional*, defaults to `False`):
            Whether to use static quantization.
        kernel_gptq (`DiffusionGPTQConfig` or `None`, *optional*, defaults to `None`):
            The GPTQ quantization configuration.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch calibration configuration.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantizer dynamic range calibration configuration.
    """

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.quant_dtype is not None:
            return
        # Nothing is quantized, so a skip list is meaningless; empty it in
        # place so derived names and caches stay clean.
        self.skips.clear()
@configclass
@dataclass
class DiffusionWeightQuantizerConfig(SkipBasedDiffusionQuantizerConfig):
    """Diffusion model weight quantizer configuration.

    Weights are fixed at inference time, so ``static`` is forced to `True`
    and removed from the constructor signature.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data types for per-group quantization.
        skips (`list[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch calibration configuration.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantizer dynamic range calibration configuration.
    """

    static: bool = field(init=False, default=True)

    @property
    def needs_calib_data(self) -> bool:
        """Whether weight quantization needs calibration data (range search)."""
        if not self.enabled_calib_range:
            return False
        return self.calib_range.needs_search
@configclass
@dataclass
class DiffusionActivationQuantizerConfig(SkipBasedDiffusionQuantizerConfig):
    """Diffusion model activation quantizer configuration.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data types for per-group quantization.
        skips (`list[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        static (`bool`, *optional*, defaults to `False`):
            Whether to use static quantization.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantizer dynamic range calibration configuration.
        allow_unsigned (`bool`, *optional*, defaults to `False`):
            Whether to allow unsigned data type for activation quantization.
    """

    # Activations never use GPTQ kernels or low-rank branches.
    kernel_gptq: None = field(init=False, default=None)
    low_rank: None = field(init=False, default=None)
    allow_unsigned: bool = False

    @property
    def needs_calib_data(self) -> bool:
        """Whether activation quantization needs calibration data."""
        if not self.enabled_calib_range:
            return False
        return self.static or self.calib_range.needs_search

    def generate_dirnames(
        self,
        *,
        prefix: str = "",
        shape: torch.Size | tuple[int, ...] = (1024, 1024, 16, 16),
        default_dtype: torch.dtype = torch.float16,
        **kwargs,
    ) -> list[str]:
        """Get the directory names of the quantization configuration.

        Args:
            prefix (`str`, *optional*, defaults to `""`):
                The prefix for the directory names.
            shape (`torch.Size` or `tuple[int, ...]`, *optional*, defaults to `(1024, 1024, 16, 16)`):
                The shape of the tensor to be quantized.
            default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
                The dtype of the tensor to be quantized.

        Returns:
            `list[str]`:
                The directory names: effective bits, dtype name (suffixed with
                ".u" when unsigned dtypes are allowed), group-shape name, and
                skipped-module name.
        """
        names = super().generate_dirnames(prefix=prefix, shape=shape, default_dtype=default_dtype)
        if self.allow_unsigned:
            # Tag the dtype entry so unsigned-enabled configs get distinct caches.
            names[1] = f"{names[1]}.u"
        return names

    def for_unsigned(self) -> "DiffusionActivationQuantizerConfig":
        """Get the quantizer configuration for unsigned activations.

        Returns:
            `DiffusionActivationQuantizerConfig`:
                A copy of this configuration with an unsigned quantization
                dtype, or ``self`` when unsigned quantization does not apply.
        """
        if not (self.allow_unsigned and isinstance(self.dtype, QuantDataType)):
            return self
        return DiffusionActivationQuantizerConfig(
            dtype=self.dtype.to_unsigned(),
            zero_point=self.zero_point,
            group_shapes=self.group_shapes,
            scale_dtypes=self.scale_dtypes,
            skips=self.skips,
            static=self.static,
            calib_range=self.calib_range,
            allow_unsigned=self.allow_unsigned,
        )
@configclass
@dataclass
class DiffusionExtraWeightQuantizerConfig(IncludeBasedConfig, DiffusionQuantizerConfig):
    """Diffusion model extra weight quantizer configuration.

    This configuration is include-based: it applies only to the modules listed
    in `includes`. It is always static, and its GPTQ kernel, low-rank branch,
    and dynamic range calibration settings are excluded from `__init__` — they
    are copied over from the main weight quantizer configuration (see
    `DiffusionModuleQuantizerConfig.__post_init__`).

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data type for per-group quantization.
        includes (`list[str]`, *optional*, defaults to `[]`):
            The keys of the modules to include.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch calibration configuration.
        calib_range (`DynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The quantizer dynamic range calibration configuration.
    """

    # extra weights are always quantized statically
    static: bool = field(init=False, default=True)
    # populated from the main weight quantizer configuration by the owning
    # `DiffusionModuleQuantizerConfig`, not by the user
    kernel_gptq: DiffusionGPTQConfig | None = field(init=False, default=None)
    low_rank: SkipBasedQuantLowRankCalibConfig | None = field(init=False, default=None)
    calib_range: SkipBasedDynamicRangeCalibConfig | None = field(init=False, default=None)

    @property
    def needs_calib_data(self) -> bool:
        """Whether calibration data is needed for the dynamic range search."""
        return self.enabled_calib_range and self.calib_range.needs_search
@configclass
@dataclass(kw_only=True)
class DiffusionModuleQuantizerConfig(EnableConfig):
    """Diffusion model module quantizer configuration.

    Args:
        wgts (`DiffusionWeightQuantizerConfig`):
            The weight quantization configuration.
        ipts (`DiffusionActivationQuantizerConfig`):
            The input activation quantization configuration.
        opts (`DiffusionActivationQuantizerConfig`):
            The output activation quantization configuration.
        extra_wgts (`DiffusionExtraWeightQuantizerConfig` or `None`, *optional*, defaults to `None`):
            The extra weight quantization configuration, applied to the modules
            it includes on top of `wgts`.
    """

    wgts: DiffusionWeightQuantizerConfig
    ipts: DiffusionActivationQuantizerConfig
    opts: DiffusionActivationQuantizerConfig
    extra_wgts: DiffusionExtraWeightQuantizerConfig | None = None
    # NOTE(review): declared but never assigned in `__post_init__`; presumably
    # populated later by the quantization pipeline — confirm before relying on it
    unsigned_ipts: DiffusionActivationQuantizerConfig = field(init=False)

    def is_enabled(self):
        """Whether any of weight, input, or output quantization is enabled."""
        return self.enabled_wgts or self.enabled_ipts or self.enabled_opts

    @property
    def enabled_wgts(self) -> bool:
        """Whether weight quantization is enabled."""
        if self.wgts is None:
            return False
        return self.wgts.is_enabled()

    @property
    def enabled_ipts(self) -> bool:
        """Whether input activation quantization is enabled."""
        if self.ipts is None:
            return False
        return self.ipts.is_enabled()

    @property
    def enabled_opts(self) -> bool:
        """Whether output activation quantization is enabled."""
        if self.opts is None:
            return False
        return self.opts.is_enabled()

    @property
    def enabled_extra_wgts(self) -> bool:
        """Whether extra weight quantization is enabled."""
        if self.extra_wgts is None:
            return False
        return self.extra_wgts.is_enabled()

    def __post_init__(self) -> None:
        if self.enabled_opts:
            raise NotImplementedError("Output activation quantization is not supported yet.")
        if self.extra_wgts is None or not self.wgts.is_enabled():
            self.extra_wgts = None
            return
        # extra weights must not cover modules skipped by the main weight quantizer
        skips = self.wgts.skips
        self.extra_wgts.includes = [key for key in self.extra_wgts.includes if key not in skips]
        if not self.extra_wgts.is_enabled():
            self.extra_wgts = None
            return
        # share the main weight quantizer's calibration settings with the extra one
        self.extra_wgts.kernel_gptq = self.wgts.kernel_gptq
        self.extra_wgts.low_rank = self.wgts.low_rank
        self.extra_wgts.calib_range = self.wgts.calib_range

    def generate_dirnames(
        self,
        *,
        prefix: str = "",
        shape: torch.Size | tuple[int, ...] = (1024, 1024, 16, 16),
        default_dtype: torch.dtype = torch.float16,
        **kwargs,
    ) -> list[str]:
        """Get the directory names of the quantization configuration.

        Args:
            prefix (`str`, *optional*, defaults to `""`):
                The prefix for the directory names.
            shape (`torch.Size` or `tuple[int, ...]`, *optional*, defaults to `(1024, 1024, 16, 16)`):
                The shape of the tensor to be quantized.
            default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
                The dtype of the tensor to be quantized.

        Returns:
            `list[str]`:
                The directory names, joining the weight ("w"), input ("x"),
                output ("y"), and optional extra-weight configurations level
                by level with "-".
        """
        combined = zip(
            self.wgts.generate_dirnames(prefix="w", shape=shape, default_dtype=default_dtype),
            self.ipts.generate_dirnames(prefix="x", shape=shape, default_dtype=default_dtype),
            self.opts.generate_dirnames(prefix="y", shape=shape, default_dtype=default_dtype),
            strict=True,
        )
        names = ["-".join(parts) for parts in combined]
        if self.extra_wgts is not None:
            extras = self.extra_wgts.generate_dirnames(prefix="w", shape=shape, default_dtype=default_dtype)
            names = ["-".join(pair) for pair in zip(names, extras, strict=True)]
        return [f"{prefix}.[{name}]" for name in names] if prefix else names

    def generate_calib_dirname(self) -> str:
        """Generate the name for quantization calibration.

        Returns:
            `str`:
                The per-tensor calibration names ("w"/"x"/"y" prefixed),
                joined with "-"; empty if nothing is calibrated.
        """
        parts: list[str] = []
        for enabled, quantizer, tag in (
            (self.enabled_wgts, self.wgts, "w"),
            (self.enabled_ipts, self.ipts, "x"),
            (self.enabled_opts, self.opts, "y"),
        ):
            if enabled:
                calib_name = quantizer.generate_calib_dirname()
                if calib_name:
                    parts.append(f"{tag}.{calib_name}")
        return "-".join(parts)
================================================
FILE: deepcompressor/app/diffusion/quant/quantizer/quantizer.py
================================================
# -*- coding: utf-8 -*-
"""Tensor Quantizer module."""
import typing as tp
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from deepcompressor.calib.config import SkipBasedQuantLowRankCalibConfig
from deepcompressor.calib.lowrank import LowRankBranch, QuantLowRankCalibrator
from deepcompressor.calib.range import calibrate_dynamic_range
from deepcompressor.data.cache import TensorsCache
from deepcompressor.data.common import TensorType
from deepcompressor.data.range import DynamicRange
from deepcompressor.quantizer.processor import Quantizer
from .config import (
DiffusionActivationQuantizerConfig,
DiffusionGPTQConfig,
DiffusionQuantizerConfig,
DiffusionWeightQuantizerConfig,
)
# public API of this module
__all__ = ["DiffusionQuantizer", "DiffusionWeightQuantizer", "DiffusionActivationQuantizer"]
@dataclass
class DiffusionQuantizer(Quantizer):
    """Denoising model quantizer class.

    Args:
        config (`DiffusionQuantizerConfig` or `None`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        tensor_type (`TensorType`, *optional*, defaults to `TensorType.Weights`):
            The type of the tensor to quantize.
        channels_dim (`int` or `None`, *optional*, defaults to `None`):
            The dimension of channels.
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        kernel (`DiffusionGPTQConfig` or `None`):
            The GPTQ kernel configuration. Not an `__init__` argument: it is
            resolved from `config.kernel_gptq` in `__post_init__`.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`):
            The quantization low-rank branch configuration. Not an `__init__`
            argument: it is resolved from `config.low_rank` in `__post_init__`.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: DiffusionQuantizerConfig
    # resolved from `config` in `__post_init__`, hence excluded from `__init__`
    kernel: DiffusionGPTQConfig | None = field(init=False)
    low_rank: SkipBasedQuantLowRankCalibConfig | None = field(init=False)
    tensor_type: TensorType = TensorType.Weights

    def __post_init__(self) -> None:
        # mirror the GPTQ kernel and low-rank branch settings from the configuration
        self.kernel = self.config.kernel_gptq
        self.low_rank = self.config.low_rank

    def calibrate_dynamic_range(
        self,
        modules: tp.Sequence[nn.Module],
        activations: TensorsCache,
        weights: tp.Sequence[nn.Parameter] | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_weights: tp.Sequence[tuple[nn.Parameter, torch.Tensor]] | None = None,
        orig_activations: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
    ) -> tp.Sequence[DynamicRange] | None:
        """Calibrate the dynamic range.

        Args:
            modules (`Sequence[nn.Module]`):
                The modules to calibrate.
            activations (`TensorsCache`):
                The inputs cache if the tensor type is not outputs, or the outputs cache if the tensor type is outputs.
            weights (`Sequence[nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights to calibrate.
                If not provided, the weights of the modules will be used.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The cache of the inputs for evaluation.
                If not provided, the `activations` cache will be used.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module to evaluate the quantization error.
                If not provided, the module to calibrate will be used.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation.
            orig_weights (`Sequence[tuple[nn.Parameter, torch.Tensor]]` or `None`, *optional*, defaults to `None`):
                The original weights.
            orig_activations (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original activations.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original evaluation inputs.

        Returns:
            `Sequence[DynamicRange]` or `None`:
                The dynamic ranges of each quantization step.
        """
        if (
            not self.is_enabled()
            or self.config.calib_range is None
            or not self.config.calib_range.is_enabled_for(self.key)
        ):
            # calibration is disabled (globally or for this key): clear any stale range
            self.dynamic_range = None
        else:
            self.dynamic_range = calibrate_dynamic_range(
                tensor_type=self.tensor_type,
                config=self.config.calib_range,
                static=self.config.static,
                quantizer=self,
                modules=modules,
                activations=activations,
                weights=weights,
                eval_inputs=eval_inputs,
                eval_module=eval_module,
                eval_kwargs=eval_kwargs,
                orig_weights=orig_weights,
                orig_activations=orig_activations,
                orig_eval_inputs=orig_eval_inputs,
            )
        return self.dynamic_range
@dataclass
class DiffusionWeightQuantizer(DiffusionQuantizer):
    """Diffusion model weight quantizer class.

    Args:
        config (`DiffusionWeightQuantizerConfig` or `None`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        kernel (`DiffusionGPTQConfig` or `None`):
            The GPTQ kernel configuration, resolved from `config` in `__post_init__`.
        low_rank (`SkipBasedQuantLowRankCalibConfig` or `None`):
            The quantization low-rank branch configuration, resolved from `config` in `__post_init__`.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: DiffusionWeightQuantizerConfig
    # weights have no channels dimension and are always of tensor type `Weights`,
    # so both fields are fixed and excluded from `__init__`
    channels_dim: None = field(init=False, default=None)
    tensor_type: TensorType = field(init=False, default=TensorType.Weights)

    def calibrate_dynamic_range(
        self,
        module: nn.Module,
        inputs: TensorsCache,
        weight: nn.Parameter | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_inputs: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
    ) -> tp.Sequence[DynamicRange] | None:
        """Calibrate the dynamic range.

        Args:
            module (`nn.Module`):
                The module to calibrate.
            inputs (`TensorsCache`):
                The inputs cache.
            weight (`nn.Parameter` or `None`, *optional*, defaults to `None`):
                The weight parameter to calibrate.
                If not provided, the weight of the `module` will be used.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The cache of the inputs for evaluation.
                If not provided, the `inputs` cache will be used.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module to evaluate the quantization error.
                If not provided, the module to calibrate will be used.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation.
            orig_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original inputs.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original evaluation inputs.

        Returns:
            `Sequence[DynamicRange]` or `None`:
                The dynamic ranges of each quantization step.
        """
        # delegate to the base implementation with single-element sequences
        return super().calibrate_dynamic_range(
            modules=[module],
            weights=[weight] if weight is not None else [module.weight],
            activations=inputs,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            orig_activations=orig_inputs,
            orig_eval_inputs=orig_eval_inputs,
        )

    def calibrate_low_rank(
        self,
        input_quantizer: "DiffusionActivationQuantizer",
        modules: tp.Sequence[nn.Module],
        inputs: TensorsCache,
        weights: tp.Sequence[nn.Parameter] | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_inputs: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
    ) -> LowRankBranch:
        """Calibrate the quantization low-rank branch.

        Args:
            input_quantizer (`DiffusionActivationQuantizer`):
                The quantizer of the input activations.
            modules (`Sequence[nn.Module]`):
                The modules whose weights share the low-rank branch.
            inputs (`TensorsCache`):
                The inputs cache.
            weights (`Sequence[nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights to calibrate. If not provided, the weights of `modules` are used.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The cache of the inputs for evaluation.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module to evaluate the quantization error.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation.
            orig_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original inputs.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original evaluation inputs.

        Returns:
            `LowRankBranch`:
                The calibrated low-rank branch.
        """
        if weights is None:
            weights = [module.weight for module in modules]
        return QuantLowRankCalibrator(
            config=self.low_rank,
            w_quantizer=self,
            x_quantizer=input_quantizer,
            develop_dtype=self.develop_dtype,
        ).calibrate(
            x_wgts=weights,
            x_acts=inputs,
            x_mods=modules,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            orig_x_acts=orig_inputs,
            orig_eval_inputs=orig_eval_inputs,
        )
@dataclass
class DiffusionActivationQuantizer(DiffusionQuantizer):
    """Diffusion model activation quantizer class.

    Args:
        config (`DiffusionActivationQuantizerConfig` or `None`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        tensor_type (`TensorType`, *optional*, defaults to `TensorType.Inputs`):
            The type of the tensor to quantize. Must not be `TensorType.Weights`
            (checked in `__post_init__`).
        channels_dim (`int` or `None`, *optional*, defaults to `None`):
            The dimension of channels. Must be provided as an `int`
            (checked in `__post_init__`).
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: DiffusionActivationQuantizerConfig
    tensor_type: TensorType = TensorType.Inputs

    def __post_init__(self) -> None:
        super().__post_init__()
        # activation quantizers handle inputs or outputs only, and require an
        # explicit channels dimension
        assert self.tensor_type != TensorType.Weights, "The tensor type cannot be weights."
        assert isinstance(self.channels_dim, int), "The channels dimension must be provided."
================================================
FILE: deepcompressor/app/diffusion/quant/rotate.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model rotation module."""
import gc
import torch
from deepcompressor.calib.rotate import (
get_rotation_matrix,
hadamard_in_channels,
rotate_in_channels,
rotate_out_channels,
)
from deepcompressor.utils import tools
from ..nn.struct import DiffusionModelStruct
from .config import DiffusionQuantConfig
# public API of this module
__all__ = ["rotate_diffusion"]
@torch.inference_mode()
def rotate_diffusion(  # noqa: C901
    model: DiffusionModelStruct, /, config: DiffusionQuantConfig
):
    """Rotate the weights of the diffusion model in place.

    Args:
        model (`DiffusionModelStruct`):
            Model to be rotated. A plain module is first wrapped with
            `DiffusionModelStruct.construct`.
        config (`DiffusionQuantConfig`):
            Quantization configuration. `config.rotation.transforms` selects which
            projections are transformed; a projection listed in both
            `config.wgts.skips` and `config.ipts.skips` is left untouched.
    """
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    # remember every linear layer's original device and dtype so they can be restored
    devices: dict[str, torch.device] = {}
    dtypes: dict[str, torch.dtype] = {}
    linears: dict[str, torch.nn.Linear] = {}
    size: float = 0  # accumulated linear weight size, in billions of elements
    for n, m in model.module.named_modules():
        if isinstance(m, torch.nn.Linear):
            devices[n] = m.weight.device
            dtypes[n] = m.weight.dtype
            linears[n] = m
            size += m.weight.numel() / 1e9
    # rotate in float32; offload to CPU when the model is large (over ~30B linear
    # weight elements) — `device=None` keeps the layer on its current device
    for linear in linears.values():
        linear.to(dtype=torch.float32, device="cpu" if size > 30 else None)
    logger = tools.logging.getLogger(f"{__name__}.Rotate")
    head_rotation = None  # lazily built once and shared across all attention blocks
    for transformer_block in model.iter_transformer_block_structs():
        logger.debug(f"- Rotating {transformer_block.name}")
        tools.logging.Formatter.indent_inc()
        for attn in transformer_block.iter_attention_structs():
            # region qkv projections: Hadamard transform on input channels
            if attn.qkv_proj_key in config.rotation.transforms:
                # apply unless the projection is skipped by both weight and input quantization
                if attn.qkv_proj_key not in config.wgts.skips or attn.qkv_proj_key not in config.ipts.skips:
                    logger.debug(f"- Hadamard transform on {attn.name}.qkv_proj (in)")
                    hadamard_in_channels(
                        attn.qkv_proj, dtype=dtypes[attn.q_proj_name], device=devices[attn.q_proj_name]
                    )
            if not attn.is_self_attn() and attn.add_qkv_proj_key in config.rotation.transforms:
                if attn.add_qkv_proj_key not in config.wgts.skips or attn.add_qkv_proj_key not in config.ipts.skips:
                    logger.debug(f"- Hadamard transform on {attn.name}.add_qkv_proj (in)")
                    hadamard_in_channels(
                        attn.add_qkv_proj, dtype=dtypes[attn.add_k_proj_name], device=devices[attn.add_k_proj_name]
                    )
            # endregion
            # region output projections: rotate (add_)v_proj output channels and
            # (add_)o_proj input channels with the same per-head rotation matrix
            if attn.out_proj_key in config.rotation.transforms or attn.add_out_proj_key in config.rotation.transforms:
                if (
                    attn.out_proj_key not in config.wgts.skips
                    or attn.out_proj_key not in config.ipts.skips
                    or attn.add_out_proj_key not in config.wgts.skips
                    or attn.add_out_proj_key not in config.ipts.skips
                ):
                    if head_rotation is None:
                        head_rotation = get_rotation_matrix(
                            attn.config.num_head_channels, random=config.rotation.random
                        )
                    if attn.v_proj is not None:
                        logger.debug(f"- Rotating {attn.v_proj_name} (out)")
                        rotate_out_channels(attn.v_proj.weight, rotation=head_rotation, bias=attn.v_proj.bias)
                    if attn.add_v_proj is not None:
                        logger.debug(f"- Rotating {attn.add_v_proj_name} (out)")
                        rotate_out_channels(attn.add_v_proj.weight, rotation=head_rotation, bias=attn.add_v_proj.bias)
                    if attn.o_proj is not None:
                        logger.debug(f"- Rotating {attn.o_proj_name} (in)")
                        rotate_in_channels(attn.o_proj.weight, rotation=head_rotation)
                    if attn.add_o_proj is not None:
                        logger.debug(f"- Rotating {attn.add_o_proj_name} (in)")
                        rotate_in_channels(attn.add_o_proj.weight, rotation=head_rotation)
            # endregion
            gc.collect()
            torch.cuda.empty_cache()
        # region feed-forward projections: Hadamard transform on input channels
        ffn, add_ffn = transformer_block.ffn_struct, transformer_block.add_ffn_struct
        if ffn.up_proj_key in config.rotation.transforms:
            if ffn.up_proj_key not in config.wgts.skips or ffn.up_proj_key not in config.ipts.skips:
                logger.debug(f"- Hadamard transform on {ffn.up_proj_name} (in)")
                hadamard_in_channels(ffn.up_projs, dtype=dtypes[ffn.up_proj_name], device=devices[ffn.up_proj_name])
        if add_ffn is not None and add_ffn.up_proj_key in config.rotation.transforms:
            if add_ffn.up_proj_key not in config.wgts.skips or add_ffn.up_proj_key not in config.ipts.skips:
                logger.debug(f"- Hadamard transform on {add_ffn.up_proj_name} (in)")
                hadamard_in_channels(
                    add_ffn.up_projs, dtype=dtypes[add_ffn.up_proj_name], device=devices[add_ffn.up_proj_name]
                )
        if ffn.down_proj_key in config.rotation.transforms:
            if ffn.down_proj_key not in config.wgts.skips or ffn.down_proj_key not in config.ipts.skips:
                logger.debug(f"- Hadamard transform on {ffn.down_proj_name} (in)")
                hadamard_in_channels(
                    ffn.down_projs, dtype=dtypes[ffn.down_proj_name], device=devices[ffn.down_proj_name]
                )
        if add_ffn is not None and add_ffn.down_proj_key in config.rotation.transforms:
            if add_ffn.down_proj_key not in config.wgts.skips or add_ffn.down_proj_key not in config.ipts.skips:
                logger.debug(f"- Hadamard transform on {add_ffn.down_proj_name} (in)")
                hadamard_in_channels(
                    add_ffn.down_projs, dtype=dtypes[add_ffn.down_proj_name], device=devices[add_ffn.down_proj_name]
                )
        # endregion
        gc.collect()
        torch.cuda.empty_cache()
        tools.logging.Formatter.indent_dec()
    # restore every linear layer's original device and dtype
    for n, m in linears.items():
        m.to(device=devices[n], dtype=dtypes[n])
================================================
FILE: deepcompressor/app/diffusion/quant/smooth.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion smooth quantization module."""
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from deepcompressor.calib.smooth import ActivationSmoother, smooth_linear_modules
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.quantizer import Quantizer
from deepcompressor.utils import tools
from deepcompressor.utils.hooks import KeyedInputPackager
from ..nn.struct import (
DiffusionAttentionStruct,
DiffusionBlockStruct,
DiffusionFeedForwardStruct,
DiffusionModelStruct,
DiffusionTransformerBlockStruct,
)
from .config import DiffusionQuantConfig
from .utils import get_needs_inputs_fn, wrap_joint_attn
# public API of this module
__all__ = ["smooth_diffusion"]
@torch.inference_mode()
def smooth_diffusion_attention(
    attn: DiffusionAttentionStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the attention query/key computation of an attention block.

    Attention (qk) smoothing is not implemented yet: when it is enabled in the
    configuration this raises `NotImplementedError`; otherwise the cache is
    returned unchanged.

    Args:
        attn (`DiffusionAttentionStruct`): the attention structure to smooth.
        config (`DiffusionQuantConfig`): the quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`): cache of smoothing scales.
        block_cache (`dict[str, IOTensorsCache]` or `None`): cached activations
            of the enclosing block (unused here).
        block_kwargs (`dict[str, tp.Any]` or `None`): forward keyword arguments
            of the enclosing block (unused here).

    Returns:
        `dict[str, torch.Tensor]`: the (unchanged) smoothing scale cache.
    """
    if not config.smooth.enabled_attn:
        return smooth_cache
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    logger.debug("- %s.k", attn.name)
    raise NotImplementedError("Not implemented yet")
@torch.inference_mode()
def smooth_diffusion_qkv_proj(
    attn: DiffusionAttentionStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the inputs of the (add_)qkv projections of an attention block.

    Args:
        attn (`DiffusionAttentionStruct`): the attention structure to smooth.
        config (`DiffusionQuantConfig`): the quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`): cache of smoothing scales keyed
            by projection name; updated in place and returned.
        block_cache (`dict[str, IOTensorsCache]` or `None`): cached activations of
            the enclosing block, used for calibration when provided.
        block_kwargs (`dict[str, tp.Any]` or `None`): forward keyword arguments of
            the enclosing block, used for evaluation.

    Returns:
        `dict[str, torch.Tensor]`: the updated smoothing scale cache.
    """
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # region qkv projection
    module_key = attn.qkv_proj_key
    # smoothing is only needed when the module's weights or inputs are quantized
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        logger.debug("- %s.qkv_proj", attn.name)
        prevs = None
        # fuse the smoothing scale into the preceding LayerNorm when allowed, but
        # only if no positional embedding sits between the norm and the projection
        if config.smooth.proj.fuse_when_possible and attn.parent.norm_type.startswith("layer_norm"):
            if not hasattr(attn.parent.module, "pos_embed") or attn.parent.module.pos_embed is None:
                prevs = attn.parent.pre_attn_norms[attn.idx]
                assert isinstance(prevs, nn.LayerNorm)
        cache_key = attn.q_proj_name
        # use the extra-weight quantizer configuration when it covers this module
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            attn.qkv_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.q_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[attn.name].inputs if block_cache else None,
            eval_module=attn,
            eval_kwargs=attn.filter_kwargs(block_kwargs),
            develop_dtype=config.develop_dtype,
        )
        if prevs is None:
            # we need to register forward pre hook to smooth inputs
            if attn.module.group_norm is None and attn.module.spatial_norm is None:
                # no norm inside the attention module: smooth its first positional input
                ActivationSmoother(
                    smooth_cache[cache_key],
                    channels_dim=-1,
                    input_packager=KeyedInputPackager(attn.module, [0]),
                ).as_hook().register(attn.module)
            else:
                # a group/spatial norm runs inside the attention module — presumably it
                # would alter smoothed inputs, so hook the q/k/v linears directly instead
                ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(attn.qkv_proj)
        # tag each projection with the cache key of its input smoothing scale
        for m in attn.qkv_proj:
            m.in_smooth_cache_key = cache_key
    # endregion
    if attn.is_self_attn():
        # self-attention has no additional (added-stream) qkv projections
        return smooth_cache
    # region additional qkv projection
    module_key = attn.add_qkv_proj_key
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    needs_quant = needs_quant and attn.add_k_proj is not None
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        logger.debug("- %s add_qkv_proj", attn.name)
        prevs = None
        pre_attn_add_norm = attn.parent.pre_attn_add_norms[attn.idx]
        # fuse the smoothing scale into the preceding added-stream LayerNorm when allowed
        if isinstance(pre_attn_add_norm, nn.LayerNorm) and config.smooth.proj.fuse_when_possible:
            prevs = pre_attn_add_norm
        cache_key = attn.add_k_proj_name
        # use the extra-weight quantizer configuration when it covers this module
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            attn.add_qkv_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.add_k_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[attn.name].inputs if block_cache else None,
            # joint attention: evaluate only the added stream (output index 1)
            eval_module=wrap_joint_attn(attn, indexes=1) if attn.is_joint_attn() else attn,
            eval_kwargs=attn.filter_kwargs(block_kwargs),
            develop_dtype=config.develop_dtype,
        )
        if prevs is None:
            # we need to register forward pre hook to smooth inputs
            ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(attn.add_qkv_proj)
        # tag each projection with the cache key of its input smoothing scale
        for m in attn.add_qkv_proj:
            m.in_smooth_cache_key = cache_key
    # endregion
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_out_proj(  # noqa: C901
    attn: DiffusionAttentionStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the inputs of the output projection(s) of an attention block.

    Handles three cases: only `out_proj`, only `add_out_proj` (joint attention),
    or both projections smoothed jointly with a shared scale.

    Args:
        attn (`DiffusionAttentionStruct`): the attention structure to smooth.
        config (`DiffusionQuantConfig`): the quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`): cache of smoothing scales keyed
            by projection name; updated in place and returned.
        block_cache (`dict[str, IOTensorsCache]` or `None`): cached activations of
            the enclosing block, used for calibration when provided.
        block_kwargs (`dict[str, tp.Any]` or `None`): forward keyword arguments of
            the enclosing block, used for evaluation.

    Returns:
        `dict[str, torch.Tensor]`: the updated smoothing scale cache.
    """
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # collect the output projection keys that are quantized and enabled for smoothing
    module_keys = []
    for module_key in (attn.out_proj_key, attn.add_out_proj_key) if attn.is_joint_attn() else (attn.out_proj_key,):
        needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
        needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
        if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
            module_keys.append(module_key)
    if not module_keys:
        return smooth_cache
    # temporarily force exclusive low-rank calibration while smoothing; restored below
    exclusive = False
    if config.enabled_wgts and config.wgts.enabled_low_rank:
        exclusive = config.wgts.low_rank.exclusive
        config.wgts.low_rank.exclusive = True
    # fold the smoothing scale into the v projections' outputs unless this is linear attention
    fuse_smooth = not attn.config.linear_attn and config.smooth.proj.fuse_when_possible
    prevs = [attn.v_proj, attn.add_v_proj] if fuse_smooth else None
    if len(module_keys) == 1 and module_keys[0] == attn.out_proj_key:
        # case 1: only the main output projection is smoothed
        logger.debug("- %s.out_proj", attn.name)
        module_key = attn.out_proj_key
        cache_key = attn.o_proj_name
        # use the extra-weight quantizer configuration when it covers this module
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            attn.o_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.o_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[attn.o_proj_name].inputs if block_cache else None,
            eval_module=attn.o_proj,
            extra_modules=[attn.add_o_proj] if attn.is_joint_attn() else None,
            develop_dtype=config.develop_dtype,
        )
    elif len(module_keys) == 1 and module_keys[0] == attn.add_out_proj_key:
        # case 2: only the added-stream output projection is smoothed
        assert attn.is_joint_attn()
        logger.debug("- %s.add_out_proj", attn.name)
        module_key = attn.add_out_proj_key
        cache_key = attn.add_o_proj_name
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            attn.add_o_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.add_o_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[attn.add_o_proj_name].inputs if block_cache else None,
            eval_module=attn.add_o_proj,
            extra_modules=[attn.o_proj],
            develop_dtype=config.develop_dtype,
        )
    else:
        # case 3: smooth both output projections jointly with a shared scale
        assert attn.is_joint_attn()
        logger.debug("- %s.out_proj + %s.add_out_proj", attn.name, attn.name)
        module_key = attn.out_proj_key
        cache_key = attn.o_proj_name
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            [attn.o_proj, attn.add_o_proj],
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.o_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[attn.name].inputs if block_cache else None,
            # evaluate both output streams of the joint attention
            eval_module=wrap_joint_attn(attn, indexes=(0, 1)),
            eval_kwargs=attn.filter_kwargs(block_kwargs),
            develop_dtype=config.develop_dtype,
        )
    # restore the low-rank `exclusive` flag changed above
    if config.enabled_wgts and config.wgts.enabled_low_rank:
        config.wgts.low_rank.exclusive = exclusive
    if fuse_smooth:
        # scale was folded into the v projections' outputs: tag them for later lookup
        for prev in prevs:
            if prev is not None:
                prev.out_smooth_cache_key = cache_key
    else:
        # no fusion target: smooth the o-projection inputs at runtime via pre-hooks
        for o_proj in [attn.o_proj, attn.add_o_proj]:
            if o_proj is not None:
                ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(o_proj)
    # tag the output projections with the cache key of their input smoothing scale
    attn.o_proj.in_smooth_cache_key = cache_key
    if attn.add_o_proj is not None:
        attn.add_o_proj.in_smooth_cache_key = cache_key
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_up_proj(
    pre_ffn_norm: nn.Module,
    ffn: DiffusionFeedForwardStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the inputs of the feed-forward up projection.

    Args:
        pre_ffn_norm (`nn.Module`): the normalization module preceding the FFN,
            a potential fusion target for the smoothing scale.
        ffn (`DiffusionFeedForwardStruct`): the feed-forward structure to smooth.
        config (`DiffusionQuantConfig`): the quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`): cache of smoothing scales keyed
            by projection name; updated in place and returned.
        block_cache (`dict[str, IOTensorsCache]` or `None`): cached activations of
            the enclosing block, used for calibration when provided.

    Returns:
        `dict[str, torch.Tensor]`: the updated smoothing scale cache.
    """
    assert len(ffn.up_projs) == 1
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # ffn up projection
    module_key = ffn.up_proj_key
    # smoothing is only needed when the module's weights or inputs are quantized
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        logger.debug("- %s.up_proj", ffn.name)
        prevs = None
        # fuse into the preceding LayerNorm only for "ada_norm"/"layer_norm" parents
        if config.smooth.proj.fuse_when_possible and isinstance(pre_ffn_norm, nn.LayerNorm):
            if ffn.parent.norm_type in ["ada_norm", "layer_norm"]:
                prevs = pre_ffn_norm
        cache_key = ffn.up_proj_name
        # channels are last for Linear down projections; otherwise dimension 1 is
        # used — presumably for conv-style modules (confirm against model structs)
        channels_dim = -1 if isinstance(ffn.down_proj, nn.Linear) else 1
        # use the extra-weight quantizer configuration when it covers this module
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            prevs,
            ffn.up_projs,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key),
            inputs=block_cache[ffn.up_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[ffn.up_proj_name].inputs if block_cache else None,
            eval_module=ffn.up_proj,
            develop_dtype=config.develop_dtype,
        )
        if prevs is None:
            # no fusion target: smooth the inputs at runtime via a forward pre-hook
            ActivationSmoother(smooth_cache[cache_key], channels_dim=channels_dim).as_hook().register(ffn.up_proj)
        # tag each projection with the cache key of its input smoothing scale
        for proj in ffn.up_projs:
            proj.in_smooth_cache_key = cache_key
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_down_proj(
    ffn: DiffusionFeedForwardStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the feed-forward down projection of a diffusion transformer block.

    Args:
        ffn (`DiffusionFeedForwardStruct`):
            The feed-forward structure.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache, updated in place.
        block_cache (`dict[str, IOTensorsCache]`, *optional*):
            The cached activations of the enclosing block.

    Returns:
        `dict[str, torch.Tensor]`:
            The updated smoothing scales cache.
    """
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # ffn down projection
    # FIX: the key must be passed verbatim (previously `.upper()` was applied,
    # which never matches the lowercase keys used by `is_enabled_for` everywhere
    # else in this file, silently disabling down-proj smoothing).
    module_key = ffn.down_proj_key
    # Smoothing only applies when the module's weights or inputs are quantized.
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        logger.debug("- %s.down_proj", ffn.name)
        cache_key = ffn.down_proj_name
        # Use the unsigned activation config when the module is marked unsigned.
        config_ipts = config.unsigned_ipts if getattr(ffn.down_proj, "unsigned", False) else config.ipts
        # Use the alternative weight-quantizer config when it covers this key.
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        # Linear layers carry channels last; conv layers carry channels at dim 1.
        channels_dim = -1 if isinstance(ffn.down_proj, nn.Linear) else 1
        smooth_cache[cache_key] = smooth_linear_modules(
            None,
            ffn.down_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config_ipts, channels_dim=channels_dim, key=module_key),
            inputs=block_cache[ffn.down_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[ffn.down_proj_name].inputs if block_cache else None,
            eval_module=ffn.down_proj,
            develop_dtype=config.develop_dtype,
        )
        ffn.down_proj.in_smooth_cache_key = cache_key
        # No fusable predecessor exists here, so always apply the scale at runtime.
        ActivationSmoother(smooth_cache[cache_key], channels_dim=channels_dim).as_hook().register(ffn.down_proj)
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_parallel_qkv_up_proj(
    block: DiffusionTransformerBlockStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Jointly smooth the first attention's QKV projections and the parallel FFN's
    up projection, which share the same input in a parallel transformer block.

    Args:
        block (`DiffusionTransformerBlockStruct`):
            The parallel transformer block.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache, updated in place.
        block_cache (`dict[str, IOTensorsCache]`, *optional*):
            The cached activations of the block.
        block_kwargs (`dict[str, tp.Any]`, *optional*):
            The keyword arguments for evaluating the block.

    Returns:
        `dict[str, torch.Tensor]`:
            The updated smoothing scales cache.
    """
    assert block.parallel
    assert len(block.ffn_struct.up_projs) == 1
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # region qkv proj + up proj
    attn, ffn = block.attn_structs[0], block.ffn_struct
    module_key = attn.qkv_proj_key
    # Smoothing only applies when the module's weights or inputs are quantized.
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        logger.debug("- %s.qkv_proj + %s.up_proj", attn.name, ffn.name)
        # Scales are keyed/cached under the q projection's input name.
        cache_key = attn.q_proj_name
        modules = attn.qkv_proj + ffn.up_projs
        # Use the alternative weight-quantizer config when it covers this key.
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            None,
            modules,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.q_proj_name].inputs if block_cache else None,
            # Evaluation runs the whole block since qkv and up proj share input.
            eval_inputs=block_cache[block.name].inputs if block_cache else None,
            eval_module=block,
            eval_kwargs=block_kwargs,
            # Separate the qkv projections from the appended up projections.
            splits=[len(attn.qkv_proj)],
            develop_dtype=config.develop_dtype,
        )
        # The shared input scale is applied at runtime on every involved module.
        ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(modules)
        for m in modules:
            m.in_smooth_cache_key = cache_key
    # endregion
    # region add qkv proj + add up proj
    if attn.is_self_attn():
        # Self-attention has no "add" projections; only the added FFN (if any)
        # still needs its up projection smoothed independently.
        if block.add_ffn_struct is not None:
            smooth_cache = smooth_diffusion_up_proj(
                pre_ffn_norm=block.pre_add_ffn_norm,
                ffn=block.add_ffn_struct,
                config=config,
                smooth_cache=smooth_cache,
                block_cache=block_cache,
            )
        return smooth_cache
    module_key = attn.add_qkv_proj_key
    needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
    needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
    if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
        add_ffn = block.add_ffn_struct
        # Scales are keyed/cached under the add_k projection's input name.
        cache_key = attn.add_k_proj_name
        modules = attn.add_qkv_proj
        if add_ffn is None:
            logger.debug("- %s.add_qkv_proj", attn.name)
        else:
            logger.debug("- %s.add_qkv_proj + %s.up_proj", attn.name, add_ffn.name)
            modules = modules + add_ffn.up_projs
        # Use the alternative weight-quantizer config when it covers this key.
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        smooth_cache[cache_key] = smooth_linear_modules(
            None,
            modules,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config_wgts, key=module_key, low_rank=config.wgts.low_rank),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=module_key),
            inputs=block_cache[attn.add_k_proj_name].inputs if block_cache else None,
            eval_inputs=block_cache[block.name].inputs if block_cache else None,
            eval_module=block,
            eval_kwargs=block_kwargs,
            develop_dtype=config.develop_dtype,
        )
        ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(modules)
        for m in modules:
            m.in_smooth_cache_key = cache_key
    # endregion
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_sequential_transformer_block(
    block: DiffusionTransformerBlockStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth all projections inside a sequential (non-parallel) transformer block.

    Each attention structure is smoothed (attention, then QKV, then output
    projection), followed by the main and the added feed-forward structures.

    Args:
        block (`DiffusionTransformerBlockStruct`):
            The sequential transformer block.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache, updated in place.
        block_cache (`dict[str, IOTensorsCache]`, *optional*):
            The cached activations of the block.
        block_kwargs (`dict[str, tp.Any]`, *optional*):
            The keyword arguments for evaluating the block.

    Returns:
        `dict[str, torch.Tensor]`:
            The updated smoothing scales cache.
    """
    assert not block.parallel

    def _smooth_ffn(norm: nn.Module, ffn: DiffusionFeedForwardStruct) -> dict[str, torch.Tensor]:
        # Smooth one feed-forward structure: up projection, then down projection.
        cache = smooth_diffusion_up_proj(
            pre_ffn_norm=norm, ffn=ffn, config=config, smooth_cache=smooth_cache, block_cache=block_cache
        )
        return smooth_diffusion_down_proj(ffn=ffn, config=config, smooth_cache=cache, block_cache=block_cache)

    for attn in block.attn_structs:
        # Attention internals, QKV projections, and the output projection, in order.
        for smooth_fn in (smooth_diffusion_attention, smooth_diffusion_qkv_proj, smooth_diffusion_out_proj):
            smooth_cache = smooth_fn(
                attn=attn, config=config, smooth_cache=smooth_cache, block_cache=block_cache, block_kwargs=block_kwargs
            )
    if block.ffn_struct is not None:
        smooth_cache = _smooth_ffn(block.pre_ffn_norm, block.ffn_struct)
    if block.add_ffn_struct is not None:
        smooth_cache = _smooth_ffn(block.pre_add_ffn_norm, block.add_ffn_struct)
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_parallel_transformer_block(
    block: DiffusionTransformerBlockStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    block_cache: dict[str, IOTensorsCache] | None = None,
    block_kwargs: dict[str, tp.Any] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth all projections inside a parallel transformer block.

    The first attention's QKV projections are smoothed jointly with the parallel
    FFN's up projection (they share the same input); remaining projections are
    smoothed individually.

    Args:
        block (`DiffusionTransformerBlockStruct`):
            The parallel transformer block (must have a feed-forward structure).
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache, updated in place.
        block_cache (`dict[str, IOTensorsCache]`, *optional*):
            The cached activations of the block.
        block_kwargs (`dict[str, tp.Any]`, *optional*):
            The keyword arguments for evaluating the block.

    Returns:
        `dict[str, torch.Tensor]`:
            The updated smoothing scales cache.
    """
    assert block.parallel
    assert block.ffn_struct is not None
    for attn in block.attn_structs:
        smooth_cache = smooth_diffusion_attention(
            attn=attn, config=config, smooth_cache=smooth_cache, block_cache=block_cache, block_kwargs=block_kwargs
        )
        if attn.idx != 0:
            # Subsequent attentions are handled like in a sequential block.
            smooth_cache = smooth_diffusion_qkv_proj(
                attn=attn, config=config, smooth_cache=smooth_cache, block_cache=block_cache, block_kwargs=block_kwargs
            )
        else:
            # The first attention shares its input with the parallel FFN up proj.
            smooth_cache = smooth_diffusion_parallel_qkv_up_proj(
                block=block,
                config=config,
                smooth_cache=smooth_cache,
                block_cache=block_cache,
                block_kwargs=block_kwargs,
            )
        smooth_cache = smooth_diffusion_out_proj(
            attn=attn, config=config, smooth_cache=smooth_cache, block_cache=block_cache, block_kwargs=block_kwargs
        )
    # Up projections were covered above; only down projections remain.
    ffns = [block.ffn_struct]
    if block.add_ffn_struct is not None:
        ffns.append(block.add_ffn_struct)
    for ffn in ffns:
        smooth_cache = smooth_diffusion_down_proj(
            ffn=ffn, config=config, smooth_cache=smooth_cache, block_cache=block_cache
        )
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_module(
    module_key: str,
    module_name: str,
    module: nn.Linear | nn.Conv2d,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    layer_cache: dict[str, IOTensorsCache] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth a standalone linear or convolutional module.

    Args:
        module_key (`str`):
            The configuration key of the module.
        module_name (`str`):
            The fully qualified name of the module.
        module (`nn.Linear` or `nn.Conv2d`):
            The module to smooth.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache, updated in place.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*):
            The cached activations of the enclosing layer.

    Returns:
        `dict[str, torch.Tensor]`:
            The updated smoothing scales cache.
    """
    assert isinstance(module, (nn.Linear, nn.Conv2d))
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    # Smoothing only applies when the module's weights or inputs are quantized
    # and projection smoothing covers this key; otherwise skip early.
    quantized = (config.enabled_wgts and config.wgts.is_enabled_for(module_key)) or (
        config.enabled_ipts and config.ipts.is_enabled_for(module_key)
    )
    if not (quantized and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key)):
        logger.debug("- Skipping Module %s", module_name)
        return smooth_cache
    logger.debug("- Smoothing Module %s", module_name)
    tools.logging.Formatter.indent_inc()
    logger.debug("- %s", module_name)
    cache_key = module_name
    # Linear layers carry channels last; conv layers carry channels at dim 1.
    channels_dim = -1 if isinstance(module, nn.Linear) else 1
    # Use the alternative weight-quantizer config when it covers this key.
    wgts_config = (
        config.extra_wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key)
        else config.wgts
    )
    smooth_cache[cache_key] = smooth_linear_modules(
        None,
        module,
        scale=smooth_cache.get(cache_key, None),
        config=config.smooth.proj,
        weight_quantizer=Quantizer(wgts_config, key=module_key),
        input_quantizer=Quantizer(config.ipts, channels_dim=channels_dim, key=module_key),
        inputs=layer_cache[module_name].inputs if layer_cache else None,
        eval_inputs=layer_cache[module_name].inputs if layer_cache else None,
        eval_module=module,
        develop_dtype=config.develop_dtype,
    )
    # Apply the smoothing scale at runtime via a pre-forward hook.
    ActivationSmoother(smooth_cache[cache_key], channels_dim=channels_dim).as_hook().register(module)
    module.in_smooth_cache_key = cache_key
    tools.logging.Formatter.indent_dec()
    return smooth_cache
@torch.inference_mode()
def smooth_diffusion_layer(
    layer: DiffusionBlockStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
) -> None:
    """Smooth a single diffusion model block.
    Args:
        layer (`DiffusionBlockStruct`):
            The diffusion block.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            The smoothing scales cache.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*):
            The layer cache.
        layer_kwargs (`dict[str, tp.Any]`, *optional*):
            The layer keyword arguments.
    """
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    logger.debug("- Smoothing Diffusion Block %s", layer.name)
    tools.logging.Formatter.indent_inc()
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    # We skip resnets since we currently cannot scale the Swish function
    # Track transformer blocks already handled so each is smoothed only once,
    # even though several of its key modules are visited by the loop below.
    visited: set[str] = set()
    for module_key, module_name, module, parent, _ in layer.named_key_modules():
        if isinstance(parent, (DiffusionAttentionStruct, DiffusionFeedForwardStruct)):
            # Key modules inside attention/FFN structures are smoothed at the
            # granularity of their enclosing transformer block.
            block = parent.parent
            assert isinstance(block, DiffusionTransformerBlockStruct)
            if block.name not in visited:
                logger.debug("- Smoothing Transformer Block %s", block.name)
                visited.add(block.name)
                tools.logging.Formatter.indent_inc()
                # Parallel and sequential blocks have different smoothing paths.
                if block.parallel:
                    smooth_cache = smooth_diffusion_parallel_transformer_block(
                        block=block,
                        config=config,
                        smooth_cache=smooth_cache,
                        block_cache=layer_cache,
                        block_kwargs=layer_kwargs,
                    )
                else:
                    smooth_cache = smooth_diffusion_sequential_transformer_block(
                        block=block,
                        config=config,
                        smooth_cache=smooth_cache,
                        block_cache=layer_cache,
                        block_kwargs=layer_kwargs,
                    )
                tools.logging.Formatter.indent_dec()
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Standalone linear/conv modules are smoothed individually.
            smooth_cache = smooth_diffusion_module(
                module_key=module_key,
                module_name=module_name,
                module=module,
                config=config,
                smooth_cache=smooth_cache,
                layer_cache=layer_cache,
            )
        else:
            # Any other module type is only an error if it was actually supposed
            # to be quantized and smoothed; otherwise it is skipped silently.
            needs_quant = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
            needs_quant = needs_quant or (config.enabled_ipts and config.ipts.is_enabled_for(module_key))
            if needs_quant and config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(module_key):
                raise NotImplementedError(f"Module {module_name} is not supported for smoothing")
            logger.debug("- Skipping Module %s", module_name)
    tools.logging.Formatter.indent_dec()
@torch.inference_mode()
def smooth_diffusion(
    model: nn.Module | DiffusionModelStruct,
    config: DiffusionQuantConfig,
    smooth_cache: dict[str, torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the diffusion model.
    Args:
        model (`nn.Module` or `DiffusionModelStruct`):
            The diffusion model.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`, *optional*):
            The smoothing scales cache.
    Returns:
        `dict[str, torch.Tensor]`:
            The smoothing scales cache.
    """
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    smooth_cache = smooth_cache or {}
    # A pre-existing cache must have been produced with the same fusion settings,
    # otherwise the cached scales would not match the current configuration.
    if config.smooth.enabled_proj:
        if smooth_cache:
            assert smooth_cache.get("proj.fuse_when_possible", True) == config.smooth.proj.fuse_when_possible
    if config.smooth.enabled_attn:
        if smooth_cache:
            assert smooth_cache.get("attn.fuse_when_possible", True) == config.smooth.attn.fuse_when_possible
    if not smooth_cache:
        # No cached scales: calibrate from activations, layer by layer.
        with tools.logging.redirect_tqdm():
            for _, (layer, layer_cache, layer_kwargs) in tqdm(
                config.calib.build_loader().iter_layer_activations(
                    model,
                    needs_inputs_fn=get_needs_inputs_fn(model, config),
                    skip_pre_modules=True,
                    skip_post_modules=True,
                ),
                desc="smoothing",
                leave=False,
                total=model.num_blocks,
                dynamic_ncols=True,
            ):
                smooth_diffusion_layer(
                    layer=layer,
                    config=config,
                    smooth_cache=smooth_cache,
                    layer_cache=layer_cache,
                    layer_kwargs=layer_kwargs,
                )
    else:
        # Cached scales available: re-apply them without running calibration data.
        for layer in model.block_structs:
            smooth_diffusion_layer(layer=layer, config=config, smooth_cache=smooth_cache)
    # Record the fusion settings used, so a later run can validate compatibility.
    if config.smooth.enabled_proj:
        smooth_cache.setdefault("proj.fuse_when_possible", config.smooth.proj.fuse_when_possible)
    if config.smooth.enabled_attn:
        smooth_cache.setdefault("attn.fuse_when_possible", config.smooth.attn.fuse_when_possible)
    return smooth_cache
================================================
FILE: deepcompressor/app/diffusion/quant/utils.py
================================================
import typing as tp
import torch
import torch.nn as nn
from ..nn.struct import DiffusionAttentionStruct, DiffusionFeedForwardStruct, DiffusionModelStruct
from .config import DiffusionQuantConfig
__all__ = ["get_needs_inputs_fn", "get_needs_outputs_fn", "wrap_joint_attn"]
def wrap_joint_attn(attn: nn.Module, /, *, indexes: int | tuple[int, ...] = 1) -> tp.Callable:
if isinstance(indexes, int):
def eval(*args, **kwargs) -> torch.Tensor:
return attn(*args, **kwargs)[indexes]
else:
def eval(*args, **kwargs) -> tuple[torch.Tensor, ...]:
tensors = attn(*args, **kwargs)
result = torch.concat([tensors[i] for i in indexes], dim=-2)
return result
return eval
def get_needs_inputs_fn(
    model: DiffusionModelStruct, config: DiffusionQuantConfig
) -> tp.Callable[[str, nn.Module], bool]:
    """Get function that checks whether the module needs to cache inputs.
    Args:
        model (`DiffusionModelStruct`):
            The diffused model.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
    Returns:
        `Callable[[str, nn.Module], bool]`:
            The function that checks whether the module needs to cache inputs.
    """
    cached_names: set[str] = set()
    for module_key, module_name, _, parent, field_name in model.named_key_modules():
        # Only modules whose weights or inputs are quantized need cached inputs.
        wgts_enabled = config.enabled_wgts and config.wgts.is_enabled_for(module_key)
        ipts_enabled = config.enabled_ipts and config.ipts.is_enabled_for(module_key)
        if not (wgts_enabled or ipts_enabled):
            continue
        if isinstance(parent, DiffusionAttentionStruct):
            if field_name.endswith("o_proj"):
                # Output projections cache their own inputs directly.
                cached_names.add(module_name)
                continue
            # QKV projections share an input, cached under a single anchor name.
            if field_name in ("q_proj", "k_proj", "v_proj"):
                anchor = parent.q_proj_name
            elif field_name in ("add_q_proj", "add_k_proj", "add_v_proj"):
                anchor = parent.add_k_proj_name
            else:
                raise RuntimeError(f"Unknown field name: {field_name}")
            cached_names.add(anchor)
            # Parallel blocks are evaluated at the block level for the first
            # attention; otherwise evaluation happens at the attention level.
            if parent.parent.parallel and parent.idx == 0:
                cached_names.add(parent.parent.name)
            else:
                cached_names.add(parent.name)
        elif isinstance(parent, DiffusionFeedForwardStruct):
            # Cache inputs of each active expert's projection.
            if field_name == "up_proj":
                cached_names.update(parent.up_proj_names[: parent.config.num_experts])
            elif field_name == "down_proj":
                cached_names.update(parent.down_proj_names[: parent.config.num_experts])
            else:
                raise RuntimeError(f"Unknown field name: {field_name}")
        else:
            # Standalone modules cache their own inputs.
            cached_names.add(module_name)

    def needs_inputs(name: str, module: nn.Module) -> bool:
        return name in cached_names

    return needs_inputs
def get_needs_outputs_fn(
    model: DiffusionModelStruct, config: DiffusionQuantConfig
) -> tp.Callable[[str, nn.Module], bool]:
    """Get function that checks whether the module needs to cache outputs.
    Args:
        model (`DiffusionModelStruct`):
            The diffused model.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
    Returns:
        `Callable[[str, nn.Module], bool]`:
            The function that checks whether the module needs to cache outputs.
    """
    # TODO: Implement the function that checks whether the module needs to cache outputs.

    def needs_outputs(name: str, module: nn.Module) -> bool:
        # Output caching is not implemented yet, so no module caches outputs.
        return False

    return needs_outputs
================================================
FILE: deepcompressor/app/diffusion/quant/weight.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model weight quantization calibration module."""
import gc
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.data.zero import ZeroPointDomain
from deepcompressor.nn.patch.lowrank import LowRankBranch
from deepcompressor.utils import tools
from ..nn.struct import DiffusionAttentionStruct, DiffusionBlockStruct, DiffusionModelStruct, DiffusionModuleStruct
from .config import DiffusionQuantConfig
from .quantizer import DiffusionActivationQuantizer, DiffusionWeightQuantizer
from .utils import get_needs_inputs_fn, wrap_joint_attn
__all__ = ["quantize_diffusion_weights", "load_diffusion_weights_state_dict"]
@torch.inference_mode()
def calibrate_diffusion_block_low_rank_branch(  # noqa: C901
    layer: DiffusionModuleStruct | DiffusionBlockStruct,
    config: DiffusionQuantConfig,
    branch_state_dict: dict[str, dict[str, torch.Tensor]],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
) -> None:
    """Calibrate low-rank branches for a block of a diffusion model.
    Args:
        layer (`DiffusionModuleStruct` or `DiffusionBlockStruct`):
            The block to calibrate.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        branch_state_dict (`dict[str, dict[str, torch.Tensor]]`):
            The state dict of the low-rank branches, updated in place.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*, defaults to `None`):
            The cache of the layer.
        layer_kwargs (`dict[str, tp.Any]`, *optional*, defaults to `None`):
            The keyword arguments for the layer.
    """
    assert config.wgts.low_rank is not None
    logger = tools.logging.getLogger(f"{__name__}.WeightQuantSVD")
    logger.debug("- Calibrating low-rank branches of block %s", layer.name)
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    for module_key, module_name, module, parent, field_name in layer.named_key_modules():
        modules, module_names = [module], [module_name]
        # In non-exclusive mode, projections sharing an input share one branch;
        # group them and process the group once (anchored at one field name).
        if not config.wgts.low_rank.exclusive:
            if field_name.endswith(("q_proj", "k_proj", "v_proj")):
                assert isinstance(parent, DiffusionAttentionStruct)
                if parent.is_self_attn():
                    if field_name == "q_proj":
                        modules, module_names = parent.qkv_proj, parent.qkv_proj_names
                    else:
                        continue
                elif parent.is_cross_attn():
                    if field_name == "add_k_proj":
                        modules.append(parent.add_v_proj)
                        module_names.append(parent.add_v_proj_name)
                    elif field_name != "q_proj":
                        continue
                else:
                    assert parent.is_joint_attn()
                    if field_name == "q_proj":
                        modules, module_names = parent.qkv_proj, parent.qkv_proj_names
                    elif field_name == "add_k_proj":
                        modules, module_names = parent.add_qkv_proj, parent.add_qkv_proj_names
                    else:
                        continue
        # Decide how to evaluate the calibration error: whole parallel block,
        # whole attention, or the module itself.
        if field_name.endswith(("q_proj", "k_proj")):
            assert isinstance(parent, DiffusionAttentionStruct)
            if parent.parent.parallel and parent.idx == 0:
                eval_module = parent.parent.module
                eval_name = parent.parent.name
                eval_kwargs = layer_kwargs
            else:
                eval_module = parent.module
                eval_name = parent.name
                eval_kwargs = parent.filter_kwargs(layer_kwargs)
            if parent.is_joint_attn() and "add_" in field_name:
                # Joint attention returns a tuple; evaluate only the added branch.
                eval_module = wrap_joint_attn(eval_module, indexes=1)
        else:
            eval_module, eval_name, eval_kwargs = module, module_name, None
        # All grouped modules must share a type; it determines the channels dim.
        if isinstance(modules[0], nn.Linear):
            assert all(isinstance(m, nn.Linear) for m in modules)
            channels_dim = -1
        else:
            assert all(isinstance(m, nn.Conv2d) for m in modules)
            channels_dim = 1
        # Use the alternative weight-quantizer config when it covers this key.
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key)
        if quantizer.is_enabled() and quantizer.is_enabled_low_rank():
            if isinstance(module, nn.Conv2d):
                # FIX: only 1x1 convolutions are supported — the branch below
                # treats the weight as 2-D (in_features = shape[1]) and `view`s
                # the effective weight back to the conv weight shape, which only
                # holds when the spatial kernel has exactly one element. The
                # previous bare `numel()` assert was always truthy and vacuous.
                assert module.weight.shape[2:].numel() == 1
            else:
                assert isinstance(module, nn.Linear)
            if module_name not in branch_state_dict:
                logger.debug("- Calibrating low-rank branch for %s", ", ".join(module_names))
                tools.logging.Formatter.indent_inc()
                branch_state_dict[module_name] = quantizer.calibrate_low_rank(
                    input_quantizer=DiffusionActivationQuantizer(
                        config.ipts, key=module_key, channels_dim=channels_dim
                    ),
                    modules=modules,
                    inputs=layer_cache[module_name].inputs if layer_cache else None,
                    eval_inputs=layer_cache[eval_name].inputs if layer_cache else None,
                    eval_module=eval_module,
                    eval_kwargs=eval_kwargs,
                ).state_dict()
                tools.logging.Formatter.indent_dec()
                gc.collect()
                torch.cuda.empty_cache()
            # One shared branch covers the concatenated output channels of the group.
            shared_branch = LowRankBranch(
                in_features=module.weight.shape[1],
                out_features=sum(m.weight.shape[0] for m in modules),
                rank=config.wgts.low_rank.rank,
            )
            shared_branch.to(device=module.weight.device, dtype=module.weight.dtype)
            shared_branch.load_state_dict(branch_state_dict[module_name])
            logger.debug(" + Adding low-rank branches to %s", ", ".join(module_names))
            if len(modules) > 1:
                # Split the shared branch along output channels: each module gets
                # the shared `a` and its own slice of `b`, and the branch's
                # effective weight is subtracted from the module weight.
                oc_idx = 0
                for module in modules:
                    branch = LowRankBranch(
                        in_features=module.weight.shape[1],
                        out_features=module.weight.shape[0],
                        rank=config.wgts.low_rank.rank,
                    )
                    branch.a = shared_branch.a
                    branch.b.to(dtype=module.weight.dtype, device=module.weight.device)
                    branch.b.weight.copy_(shared_branch.b.weight[oc_idx : oc_idx + module.weight.shape[0]])
                    oc_idx += module.weight.shape[0]
                    module.weight.data.sub_(branch.get_effective_weight().view(module.weight.data.shape))
                    branch.as_hook().register(module)
            else:
                module.weight.data.sub_(shared_branch.get_effective_weight().view(module.weight.data.shape))
                shared_branch.as_hook().register(module)
            del shared_branch
            gc.collect()
            torch.cuda.empty_cache()
@torch.inference_mode()
def update_diffusion_block_weight_quantizer_state_dict(
    layer: DiffusionModuleStruct | DiffusionBlockStruct,
    config: DiffusionQuantConfig,
    quantizer_state_dict: dict[str, dict[str, torch.Tensor | float | None]],
    layer_cache: dict[str, IOTensorsCache],
    layer_kwargs: dict[str, tp.Any],
):
    """Update the state dict of the weight quantizers for a block of a diffusion model.
    Args:
        layer (`DiffusionModuleStruct` or `DiffusionBlockStruct`):
            The block to update.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        quantizer_state_dict (`dict[str, dict[str, torch.Tensor | float | None]]`):
            The state dict of the weight quantizers, updated in place.
        layer_cache (`dict[str, IOTensorsCache]`):
            The cache of the layer.
        layer_kwargs (`dict[str, tp.Any]`):
            The keyword arguments for the layer.
    """
    logger = tools.logging.getLogger(f"{__name__}.WeightQuant")
    logger.debug("- Calibrating weights: block %s", layer.name)
    tools.logging.Formatter.indent_inc()
    for module_key, module_name, module, parent, field_name in layer.named_key_modules():
        # Decide how to evaluate calibration error: whole parallel block for the
        # first attention, the attention itself otherwise, or the module alone.
        if field_name.endswith(("q_proj", "k_proj")):
            assert isinstance(parent, DiffusionAttentionStruct)
            if parent.parent.parallel and parent.idx == 0:
                eval_module = parent.parent.module
                eval_name = parent.parent.name
                eval_kwargs = layer_kwargs
            else:
                eval_module = parent.module
                eval_name = parent.name
                eval_kwargs = parent.filter_kwargs(layer_kwargs)
            if parent.is_joint_attn() and "add_" in field_name:
                # Joint attention returns a tuple; evaluate only the added branch.
                eval_module = wrap_joint_attn(eval_module, indexes=1)
        else:
            eval_module, eval_name, eval_kwargs = module, module_name, None
        # Use the alternative weight-quantizer config when it covers this key.
        config_wgts = config.wgts
        if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
            config_wgts = config.extra_wgts
        quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key)
        if quantizer.is_enabled():
            if module_name not in quantizer_state_dict:
                # Not calibrated yet: compute the dynamic range and store it.
                logger.debug("- Calibrating %s.weight quantizer", module_name)
                quantizer.calibrate_dynamic_range(
                    module=module,
                    inputs=layer_cache[module_name].inputs if layer_cache else None,
                    eval_inputs=layer_cache[eval_name].inputs if layer_cache else None,
                    eval_module=eval_module,
                    eval_kwargs=eval_kwargs,
                )
                quantizer_state_dict[module_name] = quantizer.state_dict()
                gc.collect()
                torch.cuda.empty_cache()
            else:
                logger.debug("- Loading %s.weight quantizer", module_name)
        else:
            # Quantization disabled for this key: drop any stale entry so the
            # state dict only contains modules that will actually be quantized.
            logger.debug("- Skipping %s.weight", module_name)
            if module_name in quantizer_state_dict:
                quantizer_state_dict.pop(module_name)
    tools.logging.Formatter.indent_dec()
@torch.inference_mode()
def quantize_diffusion_block_weights(
    layer: DiffusionModuleStruct | DiffusionBlockStruct,
    config: DiffusionQuantConfig,
    quantizer_state_dict: dict[str, dict[str, torch.Tensor | float | None]],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    return_with_scale_state_dict: bool = False,
) -> dict[str, torch.Tensor | float | None]:
    """Quantize the weights of a block of a diffusion model.
    Args:
        layer (`DiffusionModuleStruct` or `DiffusionBlockStruct`):
            The block to quantize.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        quantizer_state_dict (`dict[str, dict[str, torch.Tensor | float | None]]`):
            The state dict of the weight quantizers.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*, defaults to `None`):
            The cache of the layer.
        return_with_scale_state_dict (`bool`, *optional*, defaults to `False`):
            Whether to return the scale state dict.
    Returns:
        `dict[str, torch.Tensor | float | None]`:
            The scale state dict.
    """
    logger = tools.logging.getLogger(f"{__name__}.WeightQuant")
    logger.debug("- Quantizing weights: block %s", layer.name)
    layer_cache = layer_cache or {}
    scale_state_dict: dict[str, torch.Tensor | float | None] = {}
    tools.logging.Formatter.indent_inc()
    for module_key, module_name, module, _, _ in layer.named_key_modules():
        # Only modules with a calibrated quantizer state are quantized.
        if module_name in quantizer_state_dict:
            param_name = f"{module_name}.weight"
            logger.debug("- Quantizing %s", param_name)
            # Use the alternative weight-quantizer config when it covers this key.
            config_wgts = config.wgts
            if config.enabled_extra_wgts and config.extra_wgts.is_enabled_for(module_key):
                config_wgts = config.extra_wgts
            logger.debug("    + quant_dtype: %s", str(config_wgts.dtype))
            logger.debug("    + group_shape: %s", str(config_wgts.group_shapes))
            logger.debug("    + scale_dtype: %s", str(config_wgts.scale_dtypes))
            quantizer = DiffusionWeightQuantizer(config_wgts, develop_dtype=config.develop_dtype, key=module_key)
            quantizer.load_state_dict(quantizer_state_dict[module_name], device=module.weight.device)
            result = quantizer.quantize(
                module.weight.data,
                inputs=layer_cache[module_name].inputs.front() if layer_cache else None,
                return_with_dequant=True,
                return_with_quant=return_with_scale_state_dict,
            )
            # Single-iteration compensation: store the quantization residual in a
            # low-rank side branch instead of folding it into the weight.
            # NOTE(review): this check reads `config.wgts` (not `config_wgts`) —
            # presumably extra_wgts never carries low-rank settings; confirm.
            if (
                config.wgts.enabled_low_rank
                and config.wgts.low_rank.is_enabled_for(module_key)
                and config.wgts.low_rank.compensate
                and config.wgts.low_rank.num_iters <= 1
            ):
                logger.debug("- Adding compensate low-rank branch to %s (side)", module_name)
                LowRankBranch(
                    in_features=module.weight.shape[1],
                    out_features=module.weight.shape[0],
                    rank=config.wgts.low_rank.rank,
                    weight=module.weight.data - result.data,
                ).as_hook().register(module)
            # Replace the weight with its dequantized (fake-quantized) version.
            module.weight.data = result.data
            if return_with_scale_state_dict:
                scale_state_dict.update(result.scale.state_dict(f"{param_name}.scale"))
                # Zero-point naming depends on whether it lives pre- or post-scaling.
                zero_name = "scaled_zero" if config.wgts.zero_point is ZeroPointDomain.PostScale else "zero"
                if isinstance(result.zero, torch.Tensor):
                    scale_state_dict[f"{param_name}.{zero_name}"] = result.zero.to("cpu")
                else:
                    scale_state_dict[f"{param_name}.{zero_name}"] = result.zero
            del result
            gc.collect()
            torch.cuda.empty_cache()
    tools.logging.Formatter.indent_dec()
    return scale_state_dict
@torch.inference_mode()
def quantize_diffusion_weights(
    model: nn.Module | DiffusionModelStruct,
    config: DiffusionQuantConfig,
    quantizer_state_dict: dict[str, dict[str, torch.Tensor | float | None]] | None = None,
    branch_state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
    return_with_scale_state_dict: bool = False,
) -> tuple[
    dict[str, dict[str, torch.Tensor | float | None]],
    dict[str, dict[str, torch.Tensor]],
    dict[str, torch.Tensor | float | None],
]:
    """Quantize the weights of a diffusion model.

    The pipeline runs in three stages: (1) attach low-rank branches (restored from
    ``branch_state_dict`` if given, otherwise calibrated from activations), (2)
    calibrate the per-module weight quantizers (skipped if ``quantizer_state_dict``
    is non-empty), and (3) quantize the weights in place, block by block.

    Args:
        model (`nn.Module` or `DiffusionModelStruct`):
            The diffusion model to quantize.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        quantizer_state_dict (`dict[str, dict[str, torch.Tensor | float | None]]`, *optional*, defaults to `None`):
            The state dict of the weight quantizers. If non-empty, quantizer
            calibration is skipped and these settings are reused.
        branch_state_dict (`dict[str, dict[str, torch.Tensor]]`, *optional*, defaults to `None`):
            The state dict of the low-rank branches. If non-empty, branches are
            restored from it instead of being calibrated from activation data.
        return_with_scale_state_dict (`bool`, *optional*, defaults to `False`):
            Whether to return the scale state dict.

    Returns:
        `tuple[
            dict[str, dict[str, torch.Tensor | float | None]],
            dict[str, dict[str, torch.Tensor]],
            dict[str, torch.Tensor | float | None]
        ]`:
            The state dict of the weight quantizers, the state dict of the low-rank branches, and the scale state dict.
    """
    logger = tools.logging.getLogger(f"{__name__}.WeightQuant")
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    # Normalize `None` to empty dicts; these are also mutated in place and returned.
    quantizer_state_dict = quantizer_state_dict or {}
    branch_state_dict = branch_state_dict or {}
    # Stage 1: low-rank branches (only when they are NOT the single-pass
    # "compensate" kind, which is instead added inline during block quantization).
    if config.wgts.enabled_low_rank and (not config.wgts.low_rank.compensate or config.wgts.low_rank.num_iters > 1):
        logger.info("* Adding low-rank branches to weights")
        tools.logging.Formatter.indent_inc()
        with tools.logging.redirect_tqdm():
            if branch_state_dict:
                # Branches were computed previously: just re-attach them; no
                # activation data is needed.
                for _, layer in tqdm(
                    model.get_named_layers(skip_pre_modules=True, skip_post_modules=True).items(),
                    desc="adding low-rank branches",
                    leave=False,
                    dynamic_ncols=True,
                ):
                    calibrate_diffusion_block_low_rank_branch(
                        layer=layer, config=config, branch_state_dict=branch_state_dict
                    )
            else:
                # Calibrate branches from streamed layer activations; results are
                # accumulated into `branch_state_dict` by the callee.
                for _, (layer, layer_cache, layer_kwargs) in tqdm(
                    config.calib.build_loader().iter_layer_activations(
                        model,
                        needs_inputs_fn=get_needs_inputs_fn(model, config),
                        skip_pre_modules=True,
                        skip_post_modules=True,
                    ),
                    desc="calibrating low-rank branches",
                    leave=False,
                    total=model.num_blocks,
                    dynamic_ncols=True,
                ):
                    calibrate_diffusion_block_low_rank_branch(
                        layer=layer,
                        config=config,
                        branch_state_dict=branch_state_dict,
                        layer_cache=layer_cache,
                        layer_kwargs=layer_kwargs,
                    )
        tools.logging.Formatter.indent_dec()
    # Pre/post modules are skipped entirely only when *all* of their keys are
    # configured to be skipped.
    skip_pre_modules = all(key in config.wgts.skips for key in model.get_prev_module_keys())
    skip_post_modules = all(key in config.wgts.skips for key in model.get_post_module_keys())
    with tools.logging.redirect_tqdm():
        # Stage 2: calibrate weight quantizers unless a precomputed state dict was given.
        if not quantizer_state_dict:
            if config.wgts.needs_calib_data:
                iterable = config.calib.build_loader().iter_layer_activations(
                    model,
                    needs_inputs_fn=get_needs_inputs_fn(model, config),
                    skip_pre_modules=skip_pre_modules,
                    skip_post_modules=skip_post_modules,
                )
            else:
                # No calibration data needed: fake the (layer, cache, kwargs) triple
                # with empty cache/kwargs so both paths share one loop body.
                iterable = map(  # noqa: C417
                    lambda kv: (kv[0], (kv[1], {}, {})),
                    model.get_named_layers(
                        skip_pre_modules=skip_pre_modules, skip_post_modules=skip_post_modules
                    ).items(),
                )
            for _, (layer, layer_cache, layer_kwargs) in tqdm(
                iterable,
                desc="calibrating weight quantizers",
                leave=False,
                # NOTE(review): the extra "+ int(not skip_pre_modules) * 3" suggests
                # three pre-modules are iterated when not skipped — confirm against
                # DiffusionModelStruct.get_prev_module_keys.
                total=model.num_blocks + int(not skip_post_modules) + int(not skip_pre_modules) * 3,
                dynamic_ncols=True,
            ):
                update_diffusion_block_weight_quantizer_state_dict(
                    layer=layer,
                    config=config,
                    quantizer_state_dict=quantizer_state_dict,
                    layer_cache=layer_cache,
                    layer_kwargs=layer_kwargs,
                )
        # Stage 3: quantize weights block by block; GPTQ needs activations, the
        # other modes quantize from weights alone.
        scale_state_dict: dict[str, torch.Tensor | float | None] = {}
        if config.wgts.enabled_gptq:
            iterable = config.calib.build_loader().iter_layer_activations(
                model,
                needs_inputs_fn=get_needs_inputs_fn(model, config),
                skip_pre_modules=skip_pre_modules,
                skip_post_modules=skip_post_modules,
            )
        else:
            iterable = map(  # noqa: C417
                lambda kv: (kv[0], (kv[1], {}, {})),
                model.get_named_layers(skip_pre_modules=skip_pre_modules, skip_post_modules=skip_post_modules).items(),
            )
        for _, (layer, layer_cache, _) in tqdm(
            iterable,
            desc="quantizing weights",
            leave=False,
            total=model.num_blocks + int(not skip_post_modules) + int(not skip_pre_modules) * 3,
            dynamic_ncols=True,
        ):
            layer_scale_state_dict = quantize_diffusion_block_weights(
                layer=layer,
                config=config,
                layer_cache=layer_cache,
                quantizer_state_dict=quantizer_state_dict,
                return_with_scale_state_dict=return_with_scale_state_dict,
            )
            scale_state_dict.update(layer_scale_state_dict)
    return quantizer_state_dict, branch_state_dict, scale_state_dict
@torch.inference_mode()
def load_diffusion_weights_state_dict(
    model: nn.Module | DiffusionModelStruct,
    config: DiffusionQuantConfig,
    state_dict: dict[str, torch.Tensor],
    branch_state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
) -> None:
    """Load the state dict of the weights of a diffusion model.

    Args:
        model (`nn.Module` or `DiffusionModelStruct`):
            The diffusion model to load the weights into.
        config (`DiffusionQuantConfig`):
            The quantization configuration.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the weights.
        branch_state_dict (`dict[str, dict[str, torch.Tensor]]`, *optional*, defaults to `None`):
            The state dict of the low-rank branches. Required (asserted non-None)
            when low-rank weight quantization is enabled.
    """
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    if config.enabled_wgts and config.wgts.enabled_low_rank:
        # The low-rank branch modules must be attached before `load_state_dict`,
        # since the saved weights were produced with those branches in place.
        assert branch_state_dict is not None
        for _, layer in tqdm(
            model.get_named_layers(skip_pre_modules=True, skip_post_modules=True).items(),
            desc="adding low-rank branches",
            leave=False,
            dynamic_ncols=True,
        ):
            calibrate_diffusion_block_low_rank_branch(layer=layer, config=config, branch_state_dict=branch_state_dict)
    model.module.load_state_dict(state_dict)
    # Release any temporaries created while attaching branches/loading weights.
    gc.collect()
    torch.cuda.empty_cache()
================================================
FILE: deepcompressor/app/diffusion/utils.py
================================================
import os
import random
import numpy as np
import torch
from PIL import Image
from deepcompressor.utils.common import hash_str_to_int
__all__ = ["get_control"]
def update_mask(mask: np.ndarray, x: int, y: int, radius: int | float):
mask = mask.copy()
H, W = mask.shape
for i in range(H):
for j in range(W):
if (j - x) ** 2 + (i - y) ** 2 <= radius**2:
mask[i, j] = True
return mask
def generate_mask(
    masked_ratio_range: tuple[int, int], size: int | tuple[int, int], seed: int | None = None, eps=1e-2
) -> np.ndarray:
    """Generate a random boolean mask by unioning random circles until a target
    masked-area ratio is reached (rejection sampling).

    Args:
        masked_ratio_range (`tuple[int, int]`):
            Inclusive lower/upper bounds of the target masked-area percentage (0-100).
        size (`int` or `tuple[int, int]`):
            Mask size; an int yields a square, a tuple is ``(height, width)``.
        seed (`int`, *optional*, defaults to `None`):
            If given, seeds the *global* `random` module (side effect on callers).
        eps (`float`, *optional*, defaults to `1e-2`):
            Tolerance around the target ratio for accepting/stopping.

    Returns:
        `np.ndarray`: Boolean mask of shape ``(height, width)``.
    """
    if seed is not None:
        random.seed(seed)
    # Target fraction of masked pixels, drawn uniformly from the percentage range.
    masked_ratio = random.randint(masked_ratio_range[0], masked_ratio_range[1]) / 100
    if isinstance(size, int):
        size = (size, size)
    assert len(size) == 2
    height, width = size
    mask = np.zeros((height, width), dtype=bool)
    while True:
        # NOTE(review): randint(16, ...) requires min(height, width) >= 32,
        # otherwise the upper bound is below the lower bound — confirm callers
        # never pass smaller sizes.
        radius = random.randint(16, min(height, width) // 2)
        x = random.randint(0, width - 1)
        y = random.randint(0, height - 1)
        new_mask = update_mask(mask, x, y, radius)
        # Accept the new circle only if it does not overshoot the target ratio.
        if new_mask.sum() / (height * width) <= masked_ratio + eps:
            mask = new_mask
        # Stop once the masked fraction is within tolerance of the target.
        if mask.sum() / (height * width) >= masked_ratio - eps:
            break
    return mask
def center_crop_and_resize(image: Image.Image, target_size: int | tuple[int, int]) -> Image.Image:
    """Center-crop ``image`` to the target aspect ratio, then resize it.

    Args:
        image (`Image.Image`):
            The source image.
        target_size (`int` or `tuple[int, int]`):
            Target ``(width, height)``; a single int yields a square.

    Returns:
        `Image.Image`: The cropped-and-resized image of exactly ``target_size``.
    """
    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    else:
        assert len(target_size) == 2
    target_width, target_height = target_size
    width, height = image.size
    if width / height > target_width / target_height:
        # Image is too wide: crop equal amounts from the left and right.
        new_width = height * target_width / target_height
        left = round((width - new_width) / 2)
        right = round(left + new_width)
        image = image.crop((left, 0, right, height))
    elif width / height < target_width / target_height:
        # Image is too tall: crop equal amounts from the top and bottom.
        # Bug fix: the original condition was `width / height < width / height`,
        # which is always False, so tall images were never center-cropped and got
        # squashed by the resize below instead.
        new_height = width * target_height / target_width
        top = round((height - new_height) / 2)
        bottom = round(top + new_height)
        image = image.crop((0, top, width, bottom))
    width, height = image.size
    if width != target_width or height != target_height:
        image = image.resize((target_width, target_height), Image.Resampling.BICUBIC)
    return image
def get_control(  # noqa: C901
    task: str,
    images: Image.Image | list[Image.Image],
    names: str | list[str] | None = None,
    data_root: str | None = None,
    device: str | torch.device = "cuda",
    **kwargs,
) -> Image.Image | list[Image.Image] | tuple[Image.Image, Image.Image] | tuple[list[Image.Image], list[Image.Image]]:
    """Build control inputs (and inpainting masks) for a conditional generation task.

    Args:
        task (`str`):
            One of "canny-to-image", "depth-to-image", or "inpainting".
        images (`Image.Image` or `list[Image.Image]`):
            Source image(s) from which the control inputs are derived.
        names (`str` or `list[str]`, *optional*, defaults to `None`):
            Per-image names, used to look up precomputed control images under
            ``data_root`` and to derive deterministic inpainting-mask seeds.
        data_root (`str`, *optional*, defaults to `None`):
            If set (together with ``names``), precomputed control images found in
            this directory are loaded instead of being recomputed.
        device (`str` or `torch.device`, *optional*, defaults to `"cuda"`):
            Device for the depth-estimation preprocessor.
        **kwargs:
            ``size`` (int or (width, height), default 1024) and optionally a
            preloaded ``processor``.

    Returns:
        Control image(s); for "inpainting", a tuple of (control image(s), mask image(s)).
        A single input image yields single image(s); a list yields lists.

    Raises:
        ValueError: If ``task`` is not supported.
    """
    size = kwargs.get("size", 1024)
    if isinstance(size, int):
        size = (size, size)
    assert len(size) == 2
    # Normalize to a list; the original input kind decides how results are unwrapped.
    image_batch = [images] if isinstance(images, Image.Image) else images
    if isinstance(names, str):
        names = [names]
    if task == "canny-to-image":
        processor = kwargs.get("processor", None)
        control_images = []
        for i, image in enumerate(image_batch):
            # Prefer a precomputed canny image on disk when available.
            if data_root is not None and names is not None:
                data_path = os.path.join(data_root, "canny_images", f"{names[i]}.png")
                if os.path.exists(data_path):
                    control_images.append(Image.open(data_path))
                    continue
            if processor is None:
                # Lazy import: controlnet_aux is only required when recomputing.
                from controlnet_aux import CannyDetector

                processor = CannyDetector()
            image = center_crop_and_resize(image, size)
            control_image = processor(
                image, low_threshold=50, high_threshold=200, detect_resolution=max(size), image_resolution=max(size)
            )
            control_images.append(control_image)
        if isinstance(images, Image.Image):
            return control_images[0]
        return control_images
    elif task == "depth-to-image":
        processor = kwargs.get("processor", None)
        control_images = []
        for i, image in enumerate(image_batch):
            # Prefer a precomputed depth image on disk when available.
            if data_root is not None and names is not None:
                data_path = os.path.join(data_root, "depth_images", f"{names[i]}.png")
                if os.path.exists(data_path):
                    control_images.append(Image.open(data_path))
                    continue
            if processor is None:
                # Lazy import: image_gen_aux is only required when recomputing.
                from image_gen_aux import DepthPreprocessor

                processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf").to(device)
            image = center_crop_and_resize(image, size)
            control_image = processor(image.convert("RGB"))[0].convert("RGB")
            control_images.append(control_image)
        if isinstance(images, Image.Image):
            return control_images[0]
        return control_images
    elif task == "inpainting":
        control_images, mask_images = [], []
        for i, image in enumerate(image_batch):
            name = None if names is None else names[i]
            # Prefer precomputed crop+mask pairs on disk when both exist.
            if data_root is not None and name is not None:
                cropped_image_path = os.path.join(data_root, "cropped_images", f"{name}.png")
                mask_path = os.path.join(data_root, "mask_images", f"{name}.png")
                if os.path.exists(cropped_image_path) and os.path.exists(mask_path):
                    control_images.append(Image.open(cropped_image_path).convert("RGB"))
                    mask_images.append(Image.open(mask_path))
                    continue
            image = center_crop_and_resize(image, size)
            control_images.append(image.convert("RGB"))
            # Deterministic mask per name so repeated runs mask the same region.
            if names is not None:
                seed = hash_str_to_int(names[i])
            else:
                seed = None
            mask = generate_mask((5, 60), size, seed=seed)
            mask_image = Image.fromarray(mask.astype(np.uint8) * 255)
            mask_images.append(mask_image)
        if isinstance(images, Image.Image):
            return control_images[0], mask_images[0]
        return control_images, mask_images
    else:
        raise ValueError(f"Unsupported task: {task}")
================================================
FILE: deepcompressor/app/llm/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/app/llm/cache/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/app/llm/cache/config.py
================================================
# -*- coding: utf-8 -*-
"""LLM quantization cache configuration."""
from dataclasses import dataclass, field
from omniconfig import configclass
from deepcompressor.utils.config.path import BasePathConfig
__all__ = ["LlmQuantCacheConfig", "LlmCacheConfig"]
@configclass
@dataclass
class LlmQuantCacheConfig(BasePathConfig):
    """Large language model quantization cache path.

    An empty string denotes an unset path (see `is_all_empty` usage in
    `LlmPtqRunConfig.__post_init__`).

    Args:
        rotation (`str`, *optional*, default=`""`):
            The rotation matrix cache path.
        reorder (`str`, *optional*, default=`""`):
            The reorder channel indexes cache path.
        smooth (`str`, *optional*, default=`""`):
            The smoothing scales cache path.
        wgts (`str`, *optional*, default=`""`):
            The weight quantizers state dict cache path.
        acts (`str`, *optional*, default=`""`):
            The activation quantizers state dict cache path.
    """

    rotation: str = ""
    reorder: str = ""
    smooth: str = ""
    wgts: str = ""
    acts: str = ""
@configclass
@dataclass
class LlmCacheConfig:
    """LLM quantization cache configuration.

    Attributes:
        root (`str`, *optional*, default=`""`):
            The root directory path for the cache.
        dirpath (`LlmQuantCacheConfig`, *optional*, default=`LlmQuantCacheConfig()`):
            The directory paths for the cache. Excluded from ``__init__``
            (``init=False``): it is derived from `path`/`quant` settings in
            `LlmPtqRunConfig.__post_init__`.
        path (`LlmQuantCacheConfig`, *optional*, default=`LlmQuantCacheConfig()`):
            The file paths for the cache.
    """

    root: str = field(default="")
    dirpath: LlmQuantCacheConfig = field(init=False, default_factory=LlmQuantCacheConfig)
    path: LlmQuantCacheConfig = field(default_factory=LlmQuantCacheConfig)
================================================
FILE: deepcompressor/app/llm/config.py
================================================
# -*- coding: utf-8 -*-
"""Configurations for evaluating a large language model."""
import os
import random
from dataclasses import dataclass, field
import numpy as np
import omniconfig
import torch
from omniconfig import ConfigParser, configclass
from deepcompressor.data.utils import ScaleUtils
from deepcompressor.utils.config.output import OutputConfig
from .cache.config import LlmCacheConfig, LlmQuantCacheConfig
from .eval.config import LlmEvalConfig
from .model.config import LlmModelConfig
from .quant.config import LlmQuantConfig
__all__ = [
"LlmPtqRunConfig",
"LlmCacheConfig",
"LlmQuantCacheConfig",
"LlmEvalConfig",
"LlmModelConfig",
"LlmQuantConfig",
]
@configclass
@dataclass
class LlmPtqRunConfig:
    """Top-level config of post-training quantization for a large language model.

    Args:
        cache (`LlmCacheConfig`):
            Large language model quantization cache path configuration.
        output (`OutputConfig`):
            Output directory configuration.
        model (`LlmModelConfig`):
            Large language model configuration.
        eval (`LlmEvalConfig`):
            Large language model evaluation configuration.
        quant (`LlmQuantConfig`):
            Large language model quantization configuration.
        seed (`int`, *optional*, defaults to `12345`):
            Random seed.
        skip_eval (`bool`, *optional*, defaults to `False`):
            Whether to skip evaluation.
        load_from (`str`, *optional*, defaults to `""`):
            Directory path to load the model checkpoint from.
            (Docstring fixed: the field is ``load_from``, not ``load_model``.)
        save_model (`str`, *optional*, defaults to `""`):
            Directory path to save the model checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the quantization cache on save.
    """

    cache: LlmCacheConfig
    output: OutputConfig
    model: LlmModelConfig
    eval: LlmEvalConfig
    quant: LlmQuantConfig = field(metadata={omniconfig.ARGPARSE_KWARGS: {"prefix": ""}})
    seed: int = 12345
    skip_eval: bool = False
    load_from: str = ""
    save_model: str = ""
    copy_on_save: bool = False

    def __post_init__(self):  # noqa: C901
        # region set scale default dtype
        # Fill in unspecified scale dtypes with the model's compute dtype.
        if self.quant.enabled_wgts:
            self.quant.wgts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.quant.wgts.scale_dtypes, default_dtype=self.model.dtype)
            )
        if self.quant.enabled_ipts:
            self.quant.ipts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.quant.ipts.scale_dtypes, default_dtype=self.model.dtype)
            )
        if self.quant.enabled_opts:
            self.quant.opts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.quant.opts.scale_dtypes, default_dtype=self.model.dtype)
            )
        # endregion
        # region set num_gpus and batch_size for auto parallelism of large models
        self.eval.num_gpus = min(torch.cuda.device_count(), self.eval.num_gpus)
        # Cap the eval batch size by model size (presumably billions of
        # parameters — TODO confirm against LlmModelConfig.size).
        if self.model.size < 50:
            self.eval.batch_size = min(8, self.eval.batch_size)
        elif self.model.size < 100:
            self.eval.batch_size = min(4, self.eval.batch_size)
        else:
            self.eval.batch_size = min(1, self.eval.batch_size)
        # endregion
        if self.quant.is_enabled():
            # Derive cache paths: either generate them from the quant config, or
            # (if explicit file paths were given) infer the directories from them.
            if self.cache.path.is_all_empty():
                self.cache.dirpath = self.quant.generate_cache_dirpath(
                    root=self.cache.root, seed=self.seed, default_dtype=self.model.dtype
                )
                self.cache.path = self.cache.dirpath.clone().add_children(f"{self.model.name}.pt")
            else:
                self.cache.dirpath = self.cache.path.clone().to_dirpath()
        if self.output.dirname == "default":
            self.output.dirname = self.quant.generate_default_dirname()
        self.output.dirpath = os.path.join(
            self.output.root,
            "llm",
            self.model.family,
            self.model.name,
            *self.quant.generate_dirnames(default_dtype=self.model.dtype)[:-1],
            self.quant.generate_calib_dirname(),
            self.output.dirname,
        )
        # Seed every RNG used downstream for reproducibility.
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        np.random.seed(self.seed)

    @classmethod
    def get_parser(cls) -> ConfigParser:
        """Get a parser for evaluating a large language model.

        Returns:
            `ConfigParser`: A parser for evaluating a large language model.
        """
        parser = ConfigParser("Evaluate a large language model")
        parser.add_config(cls)
        return parser
================================================
FILE: deepcompressor/app/llm/eval/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/app/llm/eval/base.py
================================================
# -*- coding: utf-8 -*-
"""Language model evaluator base."""
from abc import ABC, abstractmethod
from transformers import PreTrainedModel, PreTrainedTokenizer
__all__ = ["LlmEvaluatorBase"]
class LlmEvaluatorBase(ABC):
    """Common interface for language-model evaluators.

    Subclasses implement task filtering and the actual evaluation loop over a
    held model/tokenizer pair.
    """

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def filter_tasks(self, tasks: list[str]) -> list[str]:
        """Return the subset of ``tasks`` this evaluator supports."""
        ...

    @abstractmethod
    def evaluate(self, tasks: list[str], **kwargs) -> dict[str, dict[str, dict[str, float]]]:
        """Run evaluation on the given tasks and return the nested results mapping."""
        ...
================================================
FILE: deepcompressor/app/llm/eval/config.py
================================================
# -*- coding: utf-8 -*-
"""Language model evaluation config."""
import random
import typing as tp
from dataclasses import dataclass, field
import numpy as np
import omniconfig
import torch
from omniconfig import configclass
from transformers import PreTrainedModel, PreTrainedTokenizer
from deepcompressor.utils import tools
from .custom import LlmCustomEvaluator
from .lm_eval import LmevalEvaluator
from .longbench import LongbenchEvaluator
__all__ = ["LlmEvalConfig"]
@configclass
@dataclass
class LlmEvalConfig:
    """Large language model evaluation configuration.

    Attributes:
        num_gpus (`int`, *optional*, defaults to `1`):
            The number of GPUs to use.
        batch_size (`int`, *optional*, defaults to `1`):
            The batch size used for inference.
        tasks (`list[str]`, *optional*, defaults to `["zero-shot"]`):
            Task names, e.g. wikitext, hellaswag, piqa, winogrande. The alias
            "zero-shot" expands to a standard zero-shot task set.
        max_seq_length (`int`, *optional*, defaults to `-4096`):
            Maximum sequence length.
            If negative, sequence lengths smaller than or equal to the absolute value are used.
            If zero, the model's own maximum length is used.
        evaluators (`list[str]`, *optional*, defaults to `["gptq"]`):
            Evaluator names ("gptq", "lm_eval", or "longbench").
        num_shot (`int`, *optional*, defaults to `None`):
            The number of shots for few-shot evaluation.
        fewshot_as_multiturn (`bool`, *optional*, defaults to `False`):
            Whether to treat few-shot evaluation as multi-turn.
        apply_chat_template (`bool`, *optional*, defaults to `False`):
            Whether to apply chat template for evaluation.
    """

    num_gpus: int = field(default=1, metadata={omniconfig.ARGPARSE_ARGS: ("--num-gpus", "-n")})
    batch_size: int = 1
    tasks: list[str] = field(
        default_factory=lambda: ["zero-shot"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    max_seq_length: int = -4096
    evaluators: list[str] = field(
        default_factory=lambda: ["gptq"], metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}}
    )
    num_shot: int | None = None
    fewshot_as_multiturn: bool = False
    apply_chat_template: bool = False

    def __post_init__(self):
        # Expand the "zero-shot" alias into the standard zero-shot task set.
        if "zero-shot" in self.tasks:
            self.tasks.remove("zero-shot")
            self.tasks.extend(("wikitext", "hellaswag", "piqa", "winogrande", "arc_easy", "arc_challenge"))
        # Normalize and deduplicate (also fixes the `tast` typo in the original).
        self.tasks = sorted({task.lower() for task in self.tasks})
        self.evaluators = sorted({evaluator.lower() for evaluator in self.evaluators})
        for evaluator in self.evaluators:
            assert evaluator in ("lm_eval", "gptq", "longbench"), f"Invalid evaluator: {evaluator}"
        # Bug fix: the original compared against the misspelled "gpq", so this
        # GPTQ-only task filtering never ran.
        if len(self.evaluators) == 1 and self.evaluators[0] == "gptq":
            self.tasks = [task for task in self.tasks if task.startswith(("wikitext", "pile", "gsm8k"))]
            assert len(self.tasks) > 0, "No valid tasks for GPTQ evaluation"

    def evaluate(
        self,
        model: PreTrainedModel,
        /,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        eos_token_ids: tp.Sequence[int] = (),
        output_dirpath: str = "",
    ) -> dict[str, dict[int, dict[str, dict[tp.Any, dict[str, tp.Any]]]]]:
        """Evaluate the model.

        Args:
            model (`PreTrainedModel`):
                The model.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer.
            model_name (`str`):
                The name of the model.
            eos_token_ids (`Sequence[int]`, *optional*, defaults to `()`):
                The EOS token IDs (forwarded to the LongBench evaluator).
            output_dirpath (`str`, *optional*, defaults to `""`):
                Directory for evaluators that dump per-task outputs.

        Returns:
            `dict[str, dict[int, dict[str, dict[tp.Any, dict[str, tp.Any]]]]]`:
                The evaluation results.
                - The first key is the evaluator name.
                - The second key is the maximum sequence length.
                - The third key is the content name, e.g., "results", "versions", "config".
                - The fourth key is the task name for "results".
        """
        logger = tools.logging.getLogger(f"{__name__}.LlmEval")
        # Determine which sequence lengths to evaluate at.
        lm_max_seq_length = get_max_seq_length(model, tokenizer)
        max_seq_lengths = {2048, 4096, lm_max_seq_length}
        if self.max_seq_length < 0:
            # Negative means "every standard length up to |max_seq_length|".
            if self.max_seq_length == -1:
                max_seq_length = lm_max_seq_length
            else:
                max_seq_length = min(lm_max_seq_length, -self.max_seq_length)
            max_seq_lengths = [length for length in sorted(max_seq_lengths) if length <= max_seq_length]
        elif self.max_seq_length == 0:
            max_seq_lengths = [lm_max_seq_length]
        else:
            max_seq_lengths = [self.max_seq_length]
        results = {}
        for evaluator_name in self.evaluators:
            logger.info(f"- Evaluator: {evaluator_name}")
            tasks = list(self.tasks)
            if evaluator_name == "gptq":
                evaluator = LlmCustomEvaluator(model=model, tokenizer=tokenizer)
            elif evaluator_name == "lm_eval":
                evaluator = LmevalEvaluator(model=model, tokenizer=tokenizer, batch_size=self.batch_size)
            elif evaluator_name == "longbench":
                evaluator = LongbenchEvaluator(
                    model=model,
                    tokenizer=tokenizer,
                    model_name=model_name,
                    eos_token_ids=eos_token_ids,
                    output_dirpath=output_dirpath,
                )
            else:
                raise ValueError(f"Invalid evaluator: {evaluator_name}")
            logger.info(f"- Tasks: {tasks}")
            logger.info(f"- Batch_size: {self.batch_size}")
            rsts = {}
            tools.logging.Formatter.indent_inc()
            for max_seq_length in max_seq_lengths:
                logger.info(f"+ Max_seq_length: {max_seq_length}")
                tools.logging.Formatter.indent_inc()
                tools.logging.Formatter.indent_inc()
                # Reset all RNGs so every (evaluator, length) run is reproducible.
                torch.manual_seed(42)
                torch.cuda.manual_seed(42)
                torch.cuda.manual_seed_all(42)
                np.random.seed(42)
                random.seed(42)
                # evaluate
                rst = evaluator.evaluate(
                    tasks=tasks,
                    max_length=max_seq_length,
                    num_shot=self.num_shot,
                    fewshot_as_multiturn=self.fewshot_as_multiturn,
                    apply_chat_template=self.apply_chat_template,
                )
                rst["model"] = model_name
                tools.logging.Formatter.indent_dec()
                logger.info("- Results:")
                tools.logging.Formatter.indent_inc()
                tools.logging.info(self.make_table(rst), logger=logger)
                tools.logging.Formatter.indent_dec()
                rsts[max_seq_length] = rst
                tools.logging.Formatter.indent_dec()
            tools.logging.Formatter.indent_dec()
            results[evaluator_name] = rsts
        return results

    @staticmethod
    def make_table(rst: dict[str, dict[tp.Any, dict[str, tp.Any]]]) -> str:
        """Generate table of results.

        Args:
            rst (`dict[str, dict[tp.Any, dict[str, tp.Any]]]`):
                The evaluation results (with "results" and "versions" entries).

        Returns:
            `str`:
                The string representation of the results in a Markdown table.
        """
        from pytablewriter import MarkdownTableWriter

        md_writer = MarkdownTableWriter()
        md_writer.headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
        values = []
        for k, dic in rst["results"].items():
            version = rst["versions"][k]
            for m, v in dic.items():
                if "_stderr" in m:
                    continue
                # lm_eval metric keys look like "acc,none"; the matching stderr
                # key is "acc_stderr,none".
                mse = "_stderr,".join(m.split(","))
                appended = False
                if mse in dic:
                    se = dic[mse]
                    if isinstance(se, (int, float)):
                        values.append([k, version, m, "%.4f" % v, "±", "%.4f" % se])
                        appended = True
                if not appended and isinstance(v, (int, float)):
                    values.append([k, version, m, "%.4f" % v, "", ""])
                # Blank out the task/version columns after the first row of a task.
                k = ""
                version = ""
        md_writer.value_matrix = values
        return md_writer.dumps()
def get_max_seq_length(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, default_seq_length: int = 2048) -> int:
    """Infer the model's maximum sequence length.

    Probes well-known model-config attributes first, then falls back to the
    tokenizer's ``model_max_length``, and finally to ``default_seq_length``.
    """
    # Context-length attribute names used by different HF model configs, in
    # priority order.
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(model.config, attr):
            return getattr(model.config, attr)
    if hasattr(tokenizer, "model_max_length"):
        max_len = tokenizer.model_max_length
        # HF tokenizers report a huge sentinel when the true limit is unknown.
        if max_len == 1000000000000000019884624838656:
            return default_seq_length
        return max_len
    return default_seq_length
================================================
FILE: deepcompressor/app/llm/eval/custom.py
================================================
# -*- coding: utf-8 -*-
"""Language model customized evaluator."""
import math
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer
from .base import LlmEvaluatorBase
__all__ = ["LlmCustomEvaluator"]
class LlmCustomEvaluator(LlmEvaluatorBase):
    """GPTQ-style perplexity evaluator for perplexity-oriented tasks."""

    def filter_tasks(self, tasks: list[str]) -> list[str]:
        """Return only the perplexity tasks this evaluator supports."""
        supported_prefixes = ("wikitext", "pile")
        return [task for task in tasks if task.startswith(supported_prefixes)]

    def evaluate(
        self, tasks: list[str], max_length: int | None = None, **kwargs
    ) -> dict[str, dict[str, dict[str, float]]]:
        """Evaluate the model on the given tasks.

        Args:
            tasks (`list[str]`): List of tasks to evaluate on.
            max_length (`int`, optional, defaults to `None`): Maximum length for the model.

        Returns:
            dict[str, dict[str, dict[str, float]]]: Evaluation results `{"results": {"task": {"metric": score}}}`.
        """
        results: dict[str, dict[str, float]] = {}
        versions: dict[str, int] = {}
        for task in tasks:
            perplexity = _eval_ppl_with_gptq_evaluator(
                self.model, self.tokenizer, task=task, seq_length=max_length
            )
            results[task] = {"word_perplexity": perplexity}
            versions[task] = 1
        return {"results": results, "versions": versions}
def _eval_ppl_with_gptq_evaluator(
    model: PreTrainedModel,
    /,
    tokenizer: PreTrainedTokenizer,
    task: str,
    seq_length: int = 2048,
    max_num_samples: int = -1,
) -> float:
    """Evaluate the perplexity of a model on a task using GPTQ style evaluation.

    Args:
        model (`PreTrainedModel`):
            The model.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer.
        task (`str`):
            The task name (must start with "wikitext" or "pile").
        seq_length (`int`, *optional*, defaults to `2048`):
            The sequence length. `None` falls back to the default of 2048.
        max_num_samples (`int`, *optional*, defaults to `-1`):
            The maximum number of samples to evaluate; non-positive means all.

    Returns:
        float: The perplexity.

    Raises:
        ValueError: If the task is unknown, or the dataset is shorter than one
            ``seq_length`` window.
    """
    if seq_length is None:
        # Robustness fix: LlmCustomEvaluator.evaluate forwards max_length, whose
        # default is None; the original `assert seq_length > 0` then raised a
        # confusing TypeError.
        seq_length = 2048
    assert seq_length > 0, "seq_length must be positive"
    if task.startswith("wikitext"):
        test_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        test_dataset = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt")
    elif task.startswith("pile"):
        test_dataset = load_dataset("pile", task, split="test")
        test_dataset = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt")
    else:
        raise ValueError(f"Invalid task: {task}")
    test_dataset = test_dataset.input_ids.to(model.device)
    num_samples = test_dataset.numel() // seq_length
    if num_samples == 0:
        # Previously this fell through to a ZeroDivisionError below.
        raise ValueError(f"Dataset for task {task} has fewer than {seq_length} tokens")
    if max_num_samples > 0:
        num_samples = min(num_samples, max_num_samples)
    model = model.eval()
    nll_sum = 0.0
    for i in tqdm(range(num_samples), desc=f"evaluating on {task} with seq_length {seq_length}", dynamic_ncols=True):
        batch = test_dataset[:, (i * seq_length) : ((i + 1) * seq_length)]
        with torch.inference_mode():
            # Next-token prediction: logits for positions [0, L-1) vs labels [1, L).
            shift_logits = model(batch.to(model.device)).logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:]
            loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            # Accumulate as a Python float so per-sample GPU tensors are not
            # retained across the whole evaluation loop.
            nll_sum += loss.float().item() * seq_length
    return math.exp(nll_sum / (num_samples * seq_length))
================================================
FILE: deepcompressor/app/llm/eval/lm_eval.py
================================================
# -*- coding: utf-8 -*-
"""Language model evaluator using lm_eval."""
import lm_eval
import lm_eval.models
from transformers import PreTrainedModel, PreTrainedTokenizer
from .base import LlmEvaluatorBase
__all__ = ["LmevalEvaluator"]
class LmevalEvaluator(LlmEvaluatorBase):
    """Evaluator backed by the ``lm_eval`` harness (EleutherAI LM Evaluation Harness)."""

    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, batch_size: int = 1):
        super().__init__(model=model, tokenizer=tokenizer)
        # Wrap the HF model so the harness can drive it.
        self.lm = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer, batch_size=batch_size)

    def filter_tasks(self, tasks: list[str]) -> list[str]:
        """Filter the tasks to only include supported tasks (lm_eval supports all)."""
        return tasks

    def evaluate(
        self,
        tasks: list[str],
        max_length: int | None = None,
        num_shot: int | None = None,
        fewshot_as_multiturn: bool = False,
        apply_chat_template: bool = False,
        **kwargs,
    ) -> dict[str, dict[str, dict[str, float]]]:
        """Evaluate the model on the given tasks.

        Args:
            tasks (`list[str]`): List of tasks to evaluate on.
            max_length (`int`, optional, defaults to `None`): Maximum length for the model.
            num_shot (`int`, optional, defaults to `None`): Number of few-shot examples.
            fewshot_as_multiturn (`bool`, optional): Treat few-shot examples as multi-turn.
            apply_chat_template (`bool`, optional): Apply the tokenizer's chat template.

        Returns:
            dict[str, dict[str, dict[str, float]]]: Evaluation results `{"results": {"task": {"metric": score}}}`.
        """
        self.lm._max_length = max_length
        try:
            result = lm_eval.evaluator.simple_evaluate(
                model=self.lm,
                tasks=tasks,
                verbosity="ERROR",
                num_fewshot=num_shot,
                fewshot_as_multiturn=fewshot_as_multiturn,
                apply_chat_template=apply_chat_template,
                **kwargs,
            )
        finally:
            # Robustness fix: always restore the harness default even if
            # simple_evaluate raises, so a failed run does not leave a stale
            # max-length override on subsequent evaluations.
            self.lm._max_length = None
        # Drop bulky entries we never consume downstream.
        result.pop("samples", None)
        result.pop("config", None)
        return result
================================================
FILE: deepcompressor/app/llm/eval/longbench/__init__.py
================================================
from .eval import LongbenchEvaluator, LongbenchScorer
================================================
FILE: deepcompressor/app/llm/eval/longbench/eval.py
================================================
# -*- coding: utf-8 -*-
"""Language model evaluator for LongBench."""
import json
import os
import typing as tp
import numpy as np
import torch
import torch.utils.data
from datasets import load_dataset
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer
from deepcompressor.utils import tools
from ..base import LlmEvaluatorBase
from .metrics import (
classification_score,
code_sim_score,
count_score,
qa_f1_score,
qa_f1_zh_score,
retrieval_score,
retrieval_zh_score,
rouge_score,
rouge_zh_score,
)
__all__ = ["LongbenchEvaluator"]
class LongbenchEvaluator(LlmEvaluatorBase):
    """Evaluator for the LongBench / LongBench-E long-context benchmarks.

    Predictions are generated greedily (argmax decoding) with a manually
    threaded KV cache and scored with the task-specific metrics in
    :class:`LongbenchScorer`.
    """

    # Maximum number of newly generated tokens for each task.
    task2maxlen: dict[str, int] = {
        "narrativeqa": 128,
        "qasper": 128,
        "multifieldqa_en": 64,
        "multifieldqa_zh": 64,
        "hotpotqa": 32,
        "2wikimqa": 32,
        "musique": 32,
        "dureader": 128,
        "gov_report": 512,
        "qmsum": 512,
        "multi_news": 512,
        "vcsum": 512,
        "trec": 64,
        "triviaqa": 32,
        "samsum": 128,
        "lsht": 64,
        "passage_count": 32,
        "passage_retrieval_en": 32,
        "passage_retrieval_zh": 32,
        "lcc": 64,
        "repobench-p": 64,
    }
    # Prompt template for each task; populated from ``task2prompt.json`` at module import time.
    task2prompt: dict[str, str] = None

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        model_name: str,
        eos_token_ids: tp.Sequence[int],
        output_dirpath: str = "",
        task2maxlen: dict[str, int] = None,
        task2prompt: dict[str, str] = None,
    ):
        """Initialize the evaluator.

        Args:
            model (`PreTrainedModel`): Model to evaluate (inputs are moved to CUDA).
            tokenizer (`PreTrainedTokenizer`): Tokenizer matching the model.
            model_name (`str`): Model name, used to select chat formatting and
                response post-processing rules.
            eos_token_ids (`Sequence[int]`): Token ids that terminate generation.
            output_dirpath (`str`, *optional*, defaults to `""`): Directory to dump
                per-task predictions into; disabled when empty.
            task2maxlen (`dict[str, int]`, *optional*): Override of the per-task
                generation budgets.
            task2prompt (`dict[str, str]`, *optional*): Override of the per-task
                prompt templates.
        """
        super().__init__(model=model, tokenizer=tokenizer)
        self.model_name = model_name
        self.eos_token_ids = eos_token_ids
        if task2maxlen is not None:
            self.task2maxlen = task2maxlen
        if task2prompt is not None:
            self.task2prompt = task2prompt
        self.output_dirpath = output_dirpath
        self.logger = tools.logging.getLogger(__name__)

    def filter_tasks(self, tasks: list[str]) -> list[str]:
        """Filter the tasks to only include supported tasks.

        The meta-tasks "longbench-e" and "longbench" expand specially; other
        names are kept only if a generation budget is known for them. Tasks are
        ordered by ascending generation budget.
        """
        if "longbench-e" in tasks:
            return ["longbench-e"]
        if "longbench" in tasks:
            return sorted(self.task2maxlen.keys(), key=lambda x: self.task2maxlen[x])
        return sorted([task for task in tasks if task in self.task2maxlen], key=lambda x: self.task2maxlen[x])

    def evaluate(self, tasks: list[str], max_length: int, **kwargs) -> dict[str, dict[str, dict[str, float]]]:
        """Evaluate the model on the given tasks.

        Args:
            tasks (`list[str]`): Task names, or the single meta-task "longbench-e".
            max_length (`int`): Maximum prompt length in tokens; longer prompts
                are truncated in the middle.

        Returns:
            `dict`: ``{"results": {task: scores}, "versions": {task: 1}}``.
        """
        tools.logging.Formatter.indent_inc()
        longbench_e = False
        if "longbench-e" in tasks:
            assert len(tasks) == 1, "LongBench-E should be the only task"
            longbench_e = True
            # the 13 datasets that have a LongBench-E ("_e") variant
            tasks = [
                "hotpotqa",
                "2wikimqa",
                "triviaqa",
                "passage_count",
                "multifieldqa_en",
                "trec",
                "lcc",
                "repobench-p",
                "qasper",
                "samsum",
                "gov_report",
                "multi_news",
                "passage_retrieval_en",
            ]
        result = {"results": {}, "versions": {}}
        for task in tasks:
            self.logger.info(f"- Evaluating on {task}")
            tools.logging.Formatter.indent_inc()
            # NOTE(fix): forward the LongBench-E flag so the "<task>_e" subsets are used.
            preds = self.predict(task=task, max_length=max_length, longbench_e=longbench_e)
            if not preds:
                self.logger.warning(f"No results for {task}")
                tools.logging.Formatter.indent_dec()
                continue
            if self.output_dirpath:
                # dump raw predictions as JSON lines for offline inspection/rescoring
                self.logger.info(f"+ Saving results for {task} to {self.output_dirpath}")
                os.makedirs(os.path.join(self.output_dirpath, "longbench"), exist_ok=True)
                with open(
                    os.path.join(self.output_dirpath, "longbench", f"{task}.json"),
                    "w",
                    encoding="utf-8",
                ) as f:
                    for pred in preds:
                        json.dump(pred, f, ensure_ascii=False)
                        f.write("\n")
            predictions, answers, lengths = [], [], []
            for pred in preds:
                predictions.append(pred["prediction"])
                answers.append(pred["answers"])
                lengths.append(pred["length"])
            all_classes = preds[0]["all_classes"]
            if longbench_e:
                scores = LongbenchScorer.scorer_e(
                    task=task,
                    predictions=predictions,
                    answers=answers,
                    lengths=lengths,
                    all_classes=all_classes,
                )
            else:
                scores = {
                    "score": LongbenchScorer.score(
                        task=task,
                        predictions=predictions,
                        answers=answers,
                        all_classes=all_classes,
                    )
                }
            tools.logging.debug(f"+ Scores: {scores}", self.logger)
            result["results"][task] = scores
            result["versions"][task] = 1
            tools.logging.Formatter.indent_dec()
        tools.logging.Formatter.indent_dec()
        return result

    def predict(
        self,
        task: str,
        max_length: int,
        max_gen_length: int | None = None,
        prompt_format: str = "",
        longbench_e: bool = False,
    ) -> list[dict[str, tp.Any]]:
        """Generate greedy predictions for a single LongBench task.

        Args:
            task (`str`): LongBench task name.
            max_length (`int`): Maximum prompt length in tokens. Longer prompts
                are truncated in the middle (head + tail kept).
            max_gen_length (`int`, *optional*): Generation budget; defaults to
                ``self.task2maxlen[task]``.
            prompt_format (`str`, *optional*): Prompt template; defaults to
                ``self.task2prompt[task]``.
            longbench_e (`bool`, *optional*, defaults to `False`): Load the
                LongBench-E ("<task>_e") subset instead of the standard one.

        Returns:
            `list[dict[str, Any]]`: One record per sample with the prediction,
            reference answers, class list, and sample length.
        """
        if max_gen_length is None:
            max_gen_length = self.task2maxlen[task]
        if prompt_format == "":
            prompt_format = self.task2prompt[task]
        # NOTE(fix): LongBench-E samples live in separate "<task>_e" dataset configs.
        dataset = load_dataset("THUDM/LongBench", f"{task}_e" if longbench_e else task, split="test")
        preds = []
        pbar = tqdm(dataset)
        tools.logging.Formatter.indent_inc()
        for idx, data in enumerate(pbar):
            prompt = prompt_format.format(**data)
            # truncate to fit max_length
            # (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
            tokenized_prompt = self.tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
            if len(tokenized_prompt) > max_length:
                half = int(max_length / 2)
                prompt = self.tokenizer.decode(
                    tokenized_prompt[:half], skip_special_tokens=True
                ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            if task not in ("trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"):
                # chat models are better off without build prompts on these tasks
                prompt = self.build_chat(prompt)
            inputs = self.tokenizer(prompt, truncation=False, return_tensors="pt").to("cuda")
            pbar.set_description(f"Generating for {idx}, len={inputs.input_ids.shape[-1]}")
            # greedy decoding with a manually threaded KV cache
            with torch.no_grad():
                output = self.model(input_ids=inputs.input_ids, past_key_values=None, use_cache=True)
                past_key_values = output.past_key_values
                pred_token_idx = output.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
                generated_content = [pred_token_idx.item()]
                for _ in range(max_gen_length - 1):
                    outputs = self.model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)
                    past_key_values = outputs.past_key_values
                    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
                    generated_content += [pred_token_idx.item()]
                    if pred_token_idx.item() in self.eos_token_ids:
                        break
            # the trailing EOS token (if generated) is removed by skip_special_tokens
            pred = self.tokenizer.decode(generated_content, skip_special_tokens=True)
            pred = self.post_process(pred)
            preds.append(
                {
                    "prediction": pred,
                    "answers": data["answers"],
                    "all_classes": data["all_classes"],
                    "length": data["length"],
                }
            )
        tools.logging.Formatter.indent_dec()
        return preds

    def build_chat(self, prompt):
        """Build chat prompt for chat models (only Llama-2 needs special wrapping)."""
        if "llama-2" in self.model_name:
            prompt = f"[INST]{prompt}[/INST]"
        return prompt

    def post_process(self, response: str) -> str:
        """Strip model-specific artifacts (role tags, stop tokens) from a response."""
        if "xgen" in self.model_name:
            response = response.strip().replace("Assistant:", "")
        elif "internlm" in self.model_name:
            # NOTE(fix): the separator was an empty string, which makes ``str.split``
            # raise ``ValueError: empty separator``. InternLM terminates answers
            # with its "<eoa>" (end-of-answer) special token.
            response = response.split("<eoa>")[0]
        elif "llama-3" in self.model_name:
            # NOTE(fix): same empty-separator bug; cut at the "</s>" stop token.
            response = response.split(".assistant")[0].split("\n\nQuestion")[0].split("</s>")[0].strip()
        elif "llama-2-7b" in self.model_name and "instruct" in self.model_name and "32k" in self.model_name:
            response = (
                response.split("(Document")[0]
                .split("\n\nQuestion")[0]
                .split("\n\nAnswer")[0]
                .split("(Passage")[0]
                .strip()
            )
        return response
class LongbenchScorer:
    """Static scoring helpers mirroring the official LongBench evaluation."""

    # Metric function per task, each called as ``fn(prediction, ground_truth, all_classes=...)``.
    task2metric = {
        "narrativeqa": qa_f1_score,
        "qasper": qa_f1_score,
        "multifieldqa_en": qa_f1_score,
        "multifieldqa_zh": qa_f1_zh_score,
        "hotpotqa": qa_f1_score,
        "2wikimqa": qa_f1_score,
        "musique": qa_f1_score,
        "dureader": rouge_zh_score,
        "gov_report": rouge_score,
        "qmsum": rouge_score,
        "multi_news": rouge_score,
        "vcsum": rouge_zh_score,
        "trec": classification_score,
        "triviaqa": qa_f1_score,
        "samsum": rouge_score,
        "lsht": classification_score,
        "passage_retrieval_en": retrieval_score,
        "passage_count": count_score,
        "passage_retrieval_zh": retrieval_zh_score,
        "lcc": code_sim_score,
        "repobench-p": code_sim_score,
    }

    @staticmethod
    def score(
        task: str,
        predictions: tp.Sequence[str],
        answers: tp.Sequence[tp.Sequence[str]],
        all_classes: tp.Sequence[str],
        task2metric: tp.Mapping[str, tp.Callable[[str, str, tp.Any], float]] = None,
    ) -> float:
        """Score predictions for one task.

        For each sample, the best score over all reference answers is taken;
        the mean over samples is returned on a 0-100 scale (2 decimals).

        Args:
            task (`str`): Task name (key into ``task2metric``).
            predictions (`Sequence[str]`): Model predictions.
            answers (`Sequence[Sequence[str]]`): Reference answers per sample.
            all_classes (`Sequence[str]`): Class list for classification tasks.
            task2metric (`Mapping`, *optional*): Metric override; defaults to
                :attr:`LongbenchScorer.task2metric`.

        Returns:
            `float`: Mean best-match score in [0, 100].
        """
        if task2metric is None:
            task2metric = LongbenchScorer.task2metric
        if not predictions:
            # NOTE(fix): avoid ZeroDivisionError on an empty prediction list.
            return 0.0
        total_score = 0.0
        for prediction, ground_truths in zip(predictions, answers, strict=True):
            score = 0.0
            # Strip trailing chat artifacts before scoring.
            # NOTE(fix): one separator was an empty string, which makes ``str.split``
            # raise ``ValueError: empty separator``; it should be the "</s>" stop
            # token. A duplicated no-op ".split('\n\nQuestion')" was also removed.
            prediction = (
                prediction.split(".assistant")[0]
                .split("\n\nQuestion")[0]
                .split("</s>")[0]
                .split("(Document")[0]
                .split("\n\nAnswer")[0]
                .split("(Passage")[0]
                .strip()
            )
            if task in ["trec", "triviaqa", "samsum", "lsht"]:
                # few-shot tasks: keep only the first generated line
                prediction = prediction.lstrip("\n").split("\n")[0]
            if task in ["multifieldqa_zh", "dureader"]:
                prediction = prediction.split("问题:")[0].strip()
            if task in ["lsht"]:
                prediction = prediction.split("新闻内容:")[0].strip()
            if task in ["passage_retrieval_zh"]:
                prediction = prediction.split("请问")[0].split("提示")[0].strip()
            for ground_truth in ground_truths:
                score = max(
                    score,
                    task2metric[task](prediction, ground_truth, all_classes=all_classes),
                )
            total_score += score
        return round(100 * total_score / len(predictions), 2)

    @staticmethod
    def scorer_e(
        task: str,
        predictions: tp.Sequence[str],
        answers: tp.Sequence[tp.Sequence[str]],
        lengths: tp.Sequence[int],
        all_classes: tp.Sequence[str],
        task2metric: tp.Mapping[str, tp.Callable[[str, str, tp.Any], float]] = None,
    ) -> dict[str, float]:
        """Score predictions for LongBench-E, bucketed by context length.

        Args:
            task (`str`): Task name (key into ``task2metric``).
            predictions (`Sequence[str]`): Model predictions.
            answers (`Sequence[Sequence[str]]`): Reference answers per sample.
            lengths (`Sequence[int]`): Context length of each sample.
            all_classes (`Sequence[str]`): Class list for classification tasks.
            task2metric (`Mapping`, *optional*): Metric override.

        Returns:
            `dict[str, float]`: Mean score (0-100) per length bucket
            ("0-4k", "4-8k", "8k+"); empty buckets score 0.0.
        """
        if task2metric is None:
            task2metric = LongbenchScorer.task2metric
        scores = {"0-4k": [], "4-8k": [], "8k+": []}
        for prediction, ground_truths, length in zip(predictions, answers, lengths, strict=True):
            score = 0.0
            if task in ["trec", "triviaqa", "samsum", "lsht"]:
                prediction = prediction.lstrip("\n").split("\n")[0]
            for ground_truth in ground_truths:
                score = max(
                    score,
                    task2metric[task](prediction, ground_truth, all_classes=all_classes),
                )
            if length < 4000:
                scores["0-4k"].append(score)
            elif length < 8000:
                scores["4-8k"].append(score)
            else:
                scores["8k+"].append(score)
        for key in scores.keys():
            # NOTE(fix): ``np.mean([])`` returns NaN (with a RuntimeWarning) and
            # breaks JSON serialization; report 0.0 for empty buckets instead.
            scores[key] = round(100 * float(np.mean(scores[key])), 2) if scores[key] else 0.0
        return scores
# Initialize the evaluator's default ``task2prompt`` mapping by loading the
# ``task2prompt.json`` file that ships alongside this module.
with open(os.path.join(os.path.dirname(__file__), "task2prompt.json")) as f:
    LongbenchEvaluator.task2prompt = json.load(f)
================================================
FILE: deepcompressor/app/llm/eval/longbench/metrics.py
================================================
"""LongBench metrics."""
import re
import string
from collections import Counter
import jieba
from fuzzywuzzy import fuzz
from rouge import Rouge
__all__ = [
"classification_score",
"code_sim_score",
"count_score",
"qa_f1_score",
"qa_f1_zh_score",
"retrieval_score",
"retrieval_zh_score",
"rouge_score",
"rouge_zh_score",
]
def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    lowered = s.lower()
    # Drop all ASCII punctuation characters.
    punctuation = set(string.punctuation)
    no_punc = "".join(ch for ch in lowered if ch not in punctuation)
    # Remove English articles, substituting a space so neighboring words do not merge.
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punc)
    # Collapse every whitespace run into a single space.
    return " ".join(no_articles.split())
def normalize_zh_answer(s: str) -> str:
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        # Chinese carries no meaningful word spacing, so ALL whitespace is removed
        # (unlike the English normalizer, which collapses to single spaces).
        return "".join(text.split())

    def remove_punc(text):
        # Strip both ASCII punctuation and full-width / CJK punctuation marks.
        exclude = set(
            string.punctuation + "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~"
            "⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        )
        return "".join(ch for ch in text if ch not in exclude)

    return white_space_fix(remove_punc(s.lower()))
def count_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """Score for passage_count: fraction of numbers in the prediction equal to the answer."""
    found_numbers = re.findall(r"\d+", prediction)
    if not found_numbers:
        return 0.0
    hits = sum(1 for number in found_numbers if str(number) == str(ground_truth))
    return float(hits / len(found_numbers))
def retrieval_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """Fraction of numbers in the prediction that match the gold paragraph id."""
    # The gold answer is formatted like "Paragraph 7"; extract its id.
    gold_id = re.findall(r"Paragraph (\d+)", ground_truth)[0]
    numbers = re.findall(r"\d+", prediction)
    if not numbers:
        return 0.0
    hits = sum(1 for number in numbers if str(number) == str(gold_id))
    return hits / len(numbers)
def retrieval_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """Chinese variant of :func:`retrieval_score` using the "段落N" answer format."""
    gold_id = re.findall(r"段落(\d+)", ground_truth)[0]
    numbers = re.findall(r"\d+", prediction)
    if not numbers:
        return 0.0
    hits = sum(1 for number in numbers if str(number) == str(gold_id))
    return hits / len(numbers)
def code_sim_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """Fuzzy similarity (0-1) between the first code-like line of the prediction and the gold line."""
    candidate = ""
    # Take the first line that does not look like a comment or markdown fence.
    for line in prediction.lstrip("\n").split("\n"):
        if "`" not in line and "#" not in line and "//" not in line:
            candidate = line
            break
    return fuzz.ratio(candidate, ground_truth) / 100
def classification_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """Classification score: credit is split evenly across all plausible class matches."""
    matches = []
    for class_name in kwargs["all_classes"]:
        if class_name not in prediction:
            continue
        # Ignore classes that are mere substrings of a different gold label.
        if class_name in ground_truth and class_name != ground_truth:
            continue
        matches.append(class_name)
    return 1.0 / len(matches) if ground_truth in matches else 0.0
def rouge_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """ROUGE-L F1 between prediction and ground truth; 0.0 when rouge fails (e.g. empty input)."""
    try:
        result = Rouge().get_scores([prediction], [ground_truth], avg=True)
    except Exception:
        # the rouge package raises on degenerate inputs such as empty strings
        return 0.0
    return result["rouge-l"]["f"]
def rouge_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """ROUGE-L F1 over jieba-segmented Chinese text (tokens joined by spaces)."""
    segmented_prediction = " ".join(jieba.cut(prediction, cut_all=False))
    segmented_ground_truth = " ".join(jieba.cut(ground_truth, cut_all=False))
    return rouge_score(segmented_prediction, segmented_ground_truth)
def f1_score(prediction: list[str], ground_truth: list[str], **kwargs) -> float:
    """Token-level F1 between two token sequences.

    NOTE(fix): the original annotations said ``str``, but every caller
    (``qa_f1_score``/``qa_f1_zh_score``) passes token *lists*; ``Counter``
    intersection counts the per-token overlap with multiplicity. Also returns
    ``0.0`` (not the int ``0``) to honor the declared ``float`` return type.

    Args:
        prediction (`list[str]`): Predicted tokens.
        ground_truth (`list[str]`): Reference tokens.

    Returns:
        `float`: Harmonic mean of token precision and recall (0.0 when no overlap).
    """
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction)
    recall = num_same / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)
def qa_f1_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """QA F1 over normalized English tokens (lowercased, de-punctuated, article-free)."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    return f1_score(prediction_tokens, ground_truth_tokens)
def qa_f1_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
    """QA F1 over jieba-segmented, normalized Chinese tokens (empty tokens dropped)."""
    prediction_tokens = [normalize_zh_answer(token) for token in jieba.cut(prediction, cut_all=False)]
    ground_truth_tokens = [normalize_zh_answer(token) for token in jieba.cut(ground_truth, cut_all=False)]
    # normalization can reduce a token to the empty string; filter those out
    prediction_tokens = [token for token in prediction_tokens if token]
    ground_truth_tokens = [token for token in ground_truth_tokens if token]
    return f1_score(prediction_tokens, ground_truth_tokens)
================================================
FILE: deepcompressor/app/llm/eval/longbench/task2prompt.json
================================================
{
"narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
"qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
"multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
"hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
"dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
"gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
"qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
"multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
"vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
"trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
"triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
"samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
"lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
"passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
"passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
"passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
"lcc": "Please complete the code given below. \n{context}Next line of code:\n",
"repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
}
================================================
FILE: deepcompressor/app/llm/model/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/app/llm/model/config.py
================================================
# -*- coding: utf-8 -*-
"""Net configurations."""
import typing as tp
from dataclasses import dataclass, field
import torch
from omniconfig import configclass
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from deepcompressor.data.utils.dtype import eval_dtype
from deepcompressor.utils.config.model import BaseModelConfig
from ..nn.patch import patch_attention, patch_gemma_rms_norm
__all__ = ["LlmModelConfig"]
@configclass
@dataclass
class LlmModelConfig(BaseModelConfig):
    """Arguments for creating a large language model.

    Args:
        name (`str`):
            Name of the model. The size ("7b", "500m", "8x7b", ...) and variant
            are inferred from the hyphen-separated parts of this name.
        path (`str`, *optional*, defaults to `""`):
            Path of the model.
        root (`str`, *optional*, defaults to `""`):
            Root directory path for models.
        local_path (`str`, *optional*, defaults to `""`):
            Local path of the model.
        local_root (`str`, *optional*, defaults to `""`):
            Local root directory path for models.
        dtype (`torch.dtype`, *optional*, defaults to `None`):
            Data type of the model. If not specified, the original data type of the model will be used
            (float32 checkpoints are promoted to float16).
        use_flash_attn (`bool`, *optional*, defaults to `False`):
            Whether to load the model with the FlashAttention-2 implementation.
        fast_tokenizer (`bool`, *optional*, defaults to `True`):
            Whether to use fast tokenizer.

    Attributes:
        size (`float`):
            Size of the model in billions of parameters, inferred from the name.
        variant (`str`):
            Variant of the model, inferred from the name.
        orig_dtype (`torch.dtype`):
            Original data type recorded in the model's Hugging Face config.
    """

    # Registry of custom (model, tokenizer) factories keyed by model name.
    _model_factories: tp.ClassVar[dict[str, tp.Callable[[str], tuple[PreTrainedModel, PreTrainedTokenizer]]]] = {}

    size: float = field(init=False)
    variant: str = field(init=False)
    dtype: torch.dtype = field(default_factory=lambda s=None: eval_dtype(s, with_quant_dtype=False))
    use_flash_attn: bool = False
    fast_tokenizer: bool = True
    orig_dtype: torch.dtype = field(init=False)

    def __post_init__(self):
        parts = self.name.split("-")
        # we first infer the size, it should be a string matching "$\d+[mb]$"
        # (an optional single "x" marks a mixture-of-experts size, e.g. "8x7b");
        # everything before the size part is the family, everything after is the variant
        family, size, variant = "", "", ""
        for i, part in enumerate(parts):
            part = part.lower()
            if part[-1] == "m" or part[-1] == "b":
                _part = part[:-1].replace("x", "", 1)
                if _part.isdigit():
                    size = part
                    family = "-".join(parts[:i])
                    if len(parts) > i + 1:
                        variant = "-".join(parts[i + 1 :])
                    break
        assert size, f"Cannot infer size from {self.name}"
        assert family, f"Cannot infer family from {self.name}"
        if not self.family:
            self.family = family
        self.variant = variant
        # convert the size string into billions of parameters
        if size[-1] == "m":
            size = float(size[:-1]) / 1000
        else:
            assert size[-1] == "b"
            size = size[:-1]
            if "x" in size:
                # mixture-of-experts: total size = num_experts * per-expert size
                num_experts, expert_gb = size.split("x")
                num_experts = int(num_experts)
                expert_size = float(expert_gb)
                size = num_experts * expert_size
            else:
                size = float(size)
        self.size = size
        super().__post_init__()
        self.name = self.name.lower()
        self.family = self.family.lower()
        self.variant = self.variant.lower()
        # resolve the load dtype from the checkpoint config: keep half-precision
        # dtypes as-is, promote float32 to float16 unless an explicit dtype was set
        config = AutoConfig.from_pretrained(self.path)
        self.orig_dtype = config.torch_dtype
        if self.orig_dtype == torch.float32:
            self.dtype = self.dtype or torch.float16
        elif self.orig_dtype == torch.float16:
            self.dtype = self.dtype or torch.float16
        elif self.orig_dtype == torch.bfloat16:
            self.dtype = self.dtype or torch.bfloat16
        else:
            raise ValueError(f"Unsupported data type: {self.orig_dtype}")

    def build(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Build model and tokenizer.

        A factory registered via :meth:`register_model_factory` takes precedence
        over the default ``transformers`` auto-class loading.

        Returns:
            `tuple[PreTrainedModel, PreTrainedTokenizer]`:
                Model and tokenizer.
        """
        torch_dtype = self.dtype
        if self.name in self._model_factories:
            return self._model_factories[self.name](
                self.path, torch_dtype=torch_dtype, use_fast=self.fast_tokenizer, use_flash_attn=self.use_flash_attn
            )
        kwargs = {"torch_dtype": torch_dtype}
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            # spread the weights across all visible GPUs
            kwargs["device_map"] = "balanced"
        return self._default_build(self.path, **kwargs)

    @staticmethod
    def _default_build(path: str, **kwargs) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
        """Build model and tokenizer with the ``transformers`` auto classes.

        Args:
            path (`str`):
                Hugging Face model id or local directory.
            **kwargs:
                Extra arguments forwarded to ``AutoModelForCausalLM.from_pretrained``.
                ``use_fast`` and ``use_flash_attn`` are consumed here and not forwarded.

        Returns:
            `tuple[PreTrainedModel, PreTrainedTokenizer]`:
                Model and tokenizer (model is patched and set to eval mode).
        """
        config = AutoConfig.from_pretrained(path)
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=kwargs.pop("use_fast", True))
        if "use_flash_attn" in kwargs:
            use_flash_attn = kwargs.pop("use_flash_attn")
            if use_flash_attn:
                kwargs["attn_implementation"] = "flash_attention_2"
        model = AutoModelForCausalLM.from_pretrained(path, config=config, **kwargs)
        # route rotary embeddings through hookable submodules and normalize Gemma RMSNorm weights
        patch_attention(model)
        patch_gemma_rms_norm(model)
        model.eval()
        return model, tokenizer

    @classmethod
    def register_model_factory(
        cls,
        names: str | tuple[str, ...],
        /,
        factory: tp.Callable[[str, torch.dtype], tuple[PreTrainedModel, PreTrainedTokenizer]],
        *,
        overwrite: bool = False,
    ) -> None:
        """Register a model factory.

        Args:
            names (`str` or `tuple[str, ...]`):
                Names of the model.
            factory (`Callable[[str, torch.dtype], tuple[PreTrainedModel, PreTrainedTokenizer]]`):
                Factory function.
            overwrite (`bool`, *optional*, defaults to `False`):
                Whether to overwrite the existing factory for the model.

        Raises:
            ValueError: If a factory is already registered for one of the names
                and ``overwrite`` is False.
        """
        if isinstance(names, str):
            names = (names,)
        for name in names:
            if not overwrite and name in cls._model_factories:
                raise ValueError(f"Factory for {name} already exists")
            cls._model_factories[name] = factory
================================================
FILE: deepcompressor/app/llm/nn/__init__.py
================================================
# -*- coding: utf-8 -*-
from .struct import LlmModelStruct, LlmTransformerBlockStruct, LlmTransformerStruct
================================================
FILE: deepcompressor/app/llm/nn/patch.py
================================================
# -*- coding: utf-8 -*-
"""Llama model patcher."""
import functools
import torch
import torch.nn as nn
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
from deepcompressor.utils import tools
from deepcompressor.utils.patch import copy_func
__all__ = ["patch_attention", "patch_gemma_rms_norm", "RotaryEmbedding"]
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first_half, second_half = x[..., :half], x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
def update_rotary_cos_sin(
cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.LongTensor | None, unsqueeze_dim: int = 1
) -> tuple[torch.Tensor, torch.Tensor]:
"""Update the cos and sin tensors with new position_ids.
Args:
cos (``torch.Tensor``):
Cosine tensor.
sin (``torch.Tensor``):
Sine tensor.
position_ids (``torch.LongTensor | None``):
Position ids.
unsqueeze_dim (``int``, *optional*, defaults to ``1``):
The dimension along which to unsqueeze cos and sin.
Returns:
``tuple[torch.Tensor]``:
Updated cos and sin tensors.
"""
assert unsqueeze_dim in (1, 2), f"unsqueeze_dim must be 1 or 2, got {unsqueeze_dim}"
if position_ids is None:
if cos.ndim == 2:
cos = cos.unsqueeze(0)
if sin.ndim == 2:
sin = sin.unsqueeze(0)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
else:
cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] if unsqueeze_dim == 1
sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, seq_len, 1, dim] if unsqueeze_dim == 2
assert cos.ndim == 4, f"cos must have 4 dimensions, got {cos.ndim}"
assert sin.ndim == 4, f"sin must have 4 dimensions, got {sin.ndim}"
return cos, sin
class RotaryEmbedding(nn.Module):
    """Rotary embedding for attention.

    Wrapping the rotary application in an ``nn.Module`` gives quantization code a
    hookable call site per attention module (installed by ``patch_attention``).
    """

    def __init__(self) -> None:
        """Initialize the class (stateless module)."""
        super().__init__()

    def forward(
        self, states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
    ) -> torch.Tensor:
        """Apply rotary embedding to states.

        Args:
            states (torch.Tensor): States, laid out as (batch, heads, seq, head_dim)
                when ``unsqueeze_dim == 1`` or (batch, seq, heads, head_dim) when 2.
            cos (torch.Tensor): Cosine tensor, already broadcast to 4-D.
            sin (torch.Tensor): Sine tensor, already broadcast to 4-D.
            unsqueeze_dim (int, optional): The dimension along which to unsqueeze cos and sin.
                Defaults to ``1``.

        Returns:
            torch.Tensor: Rotated states flattened to (batch, seq, heads * head_dim).
        """
        # Standard RoPE rotation: x * cos + rotate_half(x) * sin.
        states = (states * cos) + (rotate_half(states) * sin)
        if unsqueeze_dim == 1:
            batch_size, num_heads, seq_len, head_dim = states.shape
            # Move heads behind the sequence axis before flattening them.
            # NOTE(review): ``view`` on this transposed (non-contiguous) tensor
            # looks like it would raise for seq_len > 1 with multiple heads —
            # confirm whether ``reshape`` is intended here.
            states = states.transpose(1, 2)
        else:
            batch_size, seq_len, num_heads, head_dim = states.shape
        return states.view(batch_size, seq_len, num_heads * head_dim)
def apply_rotary_pos_emb(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.LongTensor = None,
    unsqueeze_dim: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply Rotary Position Embedding to the query and key tensors.

    This free function is bound to each patched attention module via
    ``functools.partial`` (see ``patch_attention``) so that the rotation goes
    through the module's hookable ``q_rotary_emb``/``k_rotary_emb`` children.

    Args:
        q (`torch.Tensor`):
            The query tensor.
        k (`torch.Tensor`):
            The key tensor.
        cos (`torch.Tensor`):
            The cosine part of the rotary embedding.
        sin (`torch.Tensor`):
            The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Position indices used to gather cos/sin rows, e.g. offset positions
            when decoding with a KV cache; ``None`` means already aligned.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            1 when q/k are laid out as (batch, heads, seq, head_dim);
            2 when they are (batch, seq, heads, head_dim).

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`:
            The rotated query and key tensors in the same layout as the inputs.
    """
    assert unsqueeze_dim in (1, 2), f"unsqueeze_dim must be 1 or 2, got {unsqueeze_dim}"
    head_first = unsqueeze_dim == 1
    if head_first:
        batch_size, _, seq_len, head_dim = q.shape
    else:
        batch_size, seq_len, _, head_dim = q.shape
    cos, sin = update_rotary_cos_sin(cos, sin, position_ids=position_ids, unsqueeze_dim=unsqueeze_dim)
    # The rotary hooks return (batch, seq, heads * head_dim); split heads back out.
    q = self.q_rotary_emb(q, cos=cos, sin=sin, unsqueeze_dim=unsqueeze_dim)
    k = self.k_rotary_emb(k, cos=cos, sin=sin, unsqueeze_dim=unsqueeze_dim)
    q = q.view(batch_size, seq_len, -1, head_dim)
    k = k.view(batch_size, seq_len, -1, head_dim)
    if head_first:
        # restore the (batch, heads, seq, head_dim) layout expected by the caller
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
    return q, k
def patch_attention(model: nn.Module) -> nn.Module:
    """Patch attention modules so rotary embedding goes through hookable submodules.

    For every module whose class name ends with "attention", the module-global
    ``apply_rotary_pos_emb`` referenced by its ``forward`` is replaced with a
    version bound to the module (see :func:`apply_rotary_pos_emb`) that routes
    the query/key rotation through per-module ``RotaryEmbedding`` children.
    The model is modified in place and also returned.
    """
    logger = tools.logging.getLogger(f"{__name__}.ModelPatcher")
    for module_name, module in model.named_modules():
        classname = type(module).__name__
        if classname.lower().endswith("attention"):
            forward_name = ""
            if isinstance(module.forward, functools.partial):
                if hasattr(module, "_deepcompressor_orig_forward"):
                    # idempotent: a partial-wrapped forward with our marker means already patched
                    logger.info(f"- Attention in {module_name} has already been patched")
                else:
                    # this module has been wrapped in ``accelerate`` package
                    assert hasattr(module, "_old_forward")
                    assert module._old_forward is module.forward.__wrapped__
                    # only patch forwards that actually reference apply_rotary_pos_emb
                    if "apply_rotary_pos_emb" in module._old_forward.__func__.__globals__:
                        forward_name = "_old_forward"
            else:
                if "apply_rotary_pos_emb" in module.forward.__func__.__globals__:
                    forward_name = "forward"
            if forward_name:
                logger.info(f"- Patching {classname}.{forward_name} in {module_name}")
                # hookable rotary application points for query and key
                module.q_rotary_emb = RotaryEmbedding()
                module.k_rotary_emb = RotaryEmbedding()
                # bind the module as ``self`` of the free function defined above
                module.apply_rotary_pos_emb = functools.partial(apply_rotary_pos_emb, module)
                module._deepcompressor_orig_forward = getattr(module, forward_name)
                orig_forward = module._deepcompressor_orig_forward.__func__
                # recreate the forward with a globals dict whose ``apply_rotary_pos_emb``
                # resolves to the bound, hookable version instead of the library one
                new_globals = dict(orig_forward.__globals__)
                new_globals["apply_rotary_pos_emb"] = module.apply_rotary_pos_emb
                new_forward = copy_func(orig_forward, new_globals)
                setattr(module, forward_name, new_forward.__get__(module))
    return model
def gemma_rms_norm_forward(self: GemmaRMSNorm | Gemma2RMSNorm, x: torch.Tensor) -> torch.Tensor:
    """Forward function for Gemma RMSNorm, normalized to the Llama convention."""
    assert hasattr(self, "_deepcompressor_orig_forward"), "Gemma RMSNorm must be patched before calling forward"
    normalized = self._norm(x.float())
    # ``patch_gemma_rms_norm`` shifted the weight by +1, so a plain multiply here
    # reproduces Gemma's (1 + w) scaling in float32 before casting back.
    # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
    # See https://github.com/huggingface/transformers/pull/29402
    scaled = normalized * self.weight.float()
    return scaled.type_as(x)
def patch_gemma_rms_norm(model: nn.Module) -> nn.Module:
    """Patch Gemma RMSNorm.

    Gemma RMSNorm scales by ``(1 + weight)``; this patch folds the +1 into the
    stored weight (``weight.data.add_(1.0)``) and swaps in
    ``gemma_rms_norm_forward``, which multiplies by the weight directly — making
    the layer behave like a Llama-style RMSNorm for downstream quantization.
    The model is modified in place and also returned.
    """
    logger = tools.logging.getLogger(f"{__name__}.ModelPatcher")
    for module_name, module in model.named_modules():
        if isinstance(module, (GemmaRMSNorm, Gemma2RMSNorm)):
            classname = type(module).__name__
            forward_name = ""
            if hasattr(module, "_deepcompressor_orig_forward"):
                # idempotent: never shift the weight or wrap the forward twice
                logger.info(f"- {module_name} has already been patched")
            else:
                if isinstance(module.forward, functools.partial):
                    # module wrapped by the ``accelerate`` package: patch the inner forward
                    assert hasattr(module, "_old_forward")
                    assert module._old_forward is module.forward.__wrapped__
                    forward_name = "_old_forward"
                else:
                    forward_name = "forward"
            if forward_name:
                logger.info(f"- Patching {classname}.{forward_name} in {module_name}")
                # fold Gemma's implicit (1 + w) into the stored weight
                module.weight.data.add_(1.0)
                module._deepcompressor_orig_forward = getattr(module, forward_name)
                setattr(module, forward_name, functools.partial(gemma_rms_norm_forward, module))
    return model
================================================
FILE: deepcompressor/app/llm/nn/struct.py
================================================
# -*- coding: utf-8 -*-
"""Utility functions for Large Language Models."""
# region imports
import typing as tp
from dataclasses import dataclass, field
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2Config,
Gemma2DecoderLayer,
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2MLP,
Gemma2Model,
)
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaMLP,
LlamaModel,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralConfig,
MistralDecoderLayer,
MistralForCausalLM,
MistralForSequenceClassification,
MistralMLP,
MistralModel,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralConfig,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralForSequenceClassification,
MixtralModel,
MixtralSparseMoeBlock,
)
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2Config,
Qwen2DecoderLayer,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen2MLP,
Qwen2Model,
)
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5Block,
T5Config,
T5DenseActDense,
T5DenseGatedActDense,
T5EncoderModel,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)
from deepcompressor.nn.struct.attn import (
AttentionConfigStruct,
BaseTransformerStruct,
FeedForwardConfigStruct,
FeedForwardStruct,
SelfAttentionStruct,
TransformerBlockStruct,
)
from deepcompressor.nn.struct.base import BaseModuleStruct
from deepcompressor.utils.common import join_name
from .patch import RotaryEmbedding
# endregion
# Public re-exports of this module's structure classes.
__all__ = [
    "LlmConfigStruct",
    "LlmModelStruct",
    "LlmTransformerStruct",
    "LlmTransformerBlockStruct",
    "LlmSelfAttentionStruct",
    "LlmFeedForwardStruct",
]

# region type aliases
# Union aliases enumerating every supported HuggingFace module class at each
# structural level; used for isinstance dispatch and factory registration below.
ATTENTION_CLS = tp.Union[
    LlamaAttention, MistralAttention, MixtralAttention, Qwen2Attention, T5Attention, Gemma2Attention
]
# Feed-forward networks (dense MLPs and Mixtral's sparse MoE block).
FEEDFORWARD_CLS = tp.Union[
    LlamaMLP, MistralMLP, MixtralSparseMoeBlock, Qwen2MLP, T5DenseActDense, T5DenseGatedActDense, Gemma2MLP
]
# Transformer layers (one attention + one feed-forward sub-block each).
TRANSFORMER_BLOCK_CLS = tp.Union[
    LlamaDecoderLayer, MistralDecoderLayer, MixtralDecoderLayer, Qwen2DecoderLayer, T5Block, Gemma2DecoderLayer
]
# Backbone stacks of transformer blocks.
TRANSFORMER_CLS = tp.Union[LlamaModel, MistralModel, MixtralModel, Qwen2Model, T5Stack, Gemma2Model]
# Causal-LM model classes. NOTE(review): "CASUAL" is presumably a typo for
# "CAUSAL", but renaming would break references elsewhere — left as-is.
CASUALLM_CLS = tp.Union[LlamaForCausalLM, MistralForCausalLM, MixtralForCausalLM, Qwen2ForCausalLM, Gemma2ForCausalLM]
# Sequence-classification model classes.
SEQCLSLM_CLS = tp.Union[
    LlamaForSequenceClassification,
    MistralForSequenceClassification,
    MixtralForSequenceClassification,
    Qwen2ForSequenceClassification,
    Gemma2ForSequenceClassification,
]
# endregion
@dataclass(kw_only=True)
class LlmTransformerBlockConfigStruct(FeedForwardConfigStruct, AttentionConfigStruct):
    """Configuration of one LLM transformer block.

    Merges the attention-side and feed-forward-side configurations without
    introducing any fields of its own.

    Args:
        hidden_size (`int`): Number of input/output channels of the block.
        inner_size (`int`): Number of **query** channels inside the attention block.
        intermediate_size (`int`): Number of intermediate channels of the feedforward network.
        intermediate_act_type (`str`): Activation function used inside the feedforward network.
        num_query_heads (`int`): Number of attention query heads.
        num_key_value_heads (`int`): Number of attention key-value heads.
        num_experts (`int`): Number of feedforward experts.
        with_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether queries and keys are normalized.
        with_rope (`bool`): Whether Rotary Positional Encoding (RoPE) is applied.

    Attributes:
        head_size (`int`): Per-head size, `num_query_channels // num_query_heads`.
        num_key_value_groups (`int`): `num_query_heads // num_key_value_heads`.
        intermediate_lowerbound (`float` or `None`):
            Lower bound of the feedforward intermediate activations.
    """
@dataclass(kw_only=True)
class LlmTransformerConfigStruct(LlmTransformerBlockConfigStruct):
    """Configuration of the LLM transformer backbone (a stack of blocks).

    Extends `LlmTransformerBlockConfigStruct` (see that class for the inherited
    per-block fields and derived attributes) with model-level vocabulary and
    depth information.

    Args:
        vocab_size (`int`): Size of the token vocabulary.
        num_hidden_layers (`int`): Number of transformer blocks in the stack.
    """

    # Size of the token vocabulary.
    vocab_size: int
    # Number of transformer blocks in the backbone.
    num_hidden_layers: int
@dataclass(kw_only=True)
class LlmConfigStruct(LlmTransformerConfigStruct):
    """Top-level configuration of a large language model.

    Extends `LlmTransformerConfigStruct` (see that class for the inherited
    fields and derived attributes) with embedding/head weight-tying information.

    Args:
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the head weights are tied to the token embeddings.
    """

    # Whether the output head shares its weights with the token embeddings.
    tie_word_embeddings: bool = False
@dataclass(kw_only=True)
class LlmSelfAttentionStruct(SelfAttentionStruct):
    """Large Language Model Attention Block."""

    # region relative keys
    q_rkey: tp.ClassVar[str] = "attn_q"
    k_rkey: tp.ClassVar[str] = "attn_k"
    v_rkey: tp.ClassVar[str] = "attn_v"
    # endregion

    # Enclosing transformer-block struct (`None` when constructed standalone).
    parent: tp.Optional["LlmTransformerBlockStruct"] = field(repr=False)
    # Names of the keyword arguments this attention module's forward() accepts;
    # consumed by `filter_kwargs` below.
    kwargs: tuple[str, ...]

    def filter_kwargs(self, kwargs: dict) -> dict:
        """Filter layer kwargs to attn kwargs.

        Args:
            kwargs (`dict`): Keyword arguments passed to the enclosing layer.

        Returns:
            `dict`: The subset of `kwargs` whose keys this attention accepts.
        """
        return {k: v for k, v in kwargs.items() if k in self.kwargs}

    @staticmethod
    def _default_construct(
        module: ATTENTION_CLS,
        /,
        parent: tp.Optional["LlmTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "LlmSelfAttentionStruct":
        """Construct an `LlmSelfAttentionStruct` from a supported attention module.

        Args:
            module (`ATTENTION_CLS`): A supported HuggingFace attention module.
            parent (`LlmTransformerBlockStruct` or `None`): Enclosing block struct.
            fname (`str`): Field name of this struct in its parent.
            rname (`str`): Module name of `module` relative to its parent.
            rkey (`str`): Relative key of this struct.
            idx (`int`): Index of this struct among its siblings.

        Returns:
            `LlmSelfAttentionStruct`: The constructed struct.

        Raises:
            ValueError: If `module` is not a supported attention type.
        """
        if isinstance(module, T5Attention):
            # T5 uses relative position biases rather than RoPE, and has no
            # grouped-query attention: key/value heads equal query heads.
            with_rope, num_query_heads, num_key_value_heads = False, module.n_heads, module.n_heads
            q_proj, k_proj, v_proj, o_proj = module.q, module.k, module.v, module.o
            q_proj_rname, k_proj_rname, v_proj_rname, o_proj_rname = "q", "k", "v", "o"
            # Without RoPE, the q/k/v hook points are the projections themselves.
            q, k, v = module.q, module.k, module.v
            q_rname, k_rname, v_rname = "q", "k", "v"
            # Keyword arguments accepted by T5Attention.forward().
            kwargs = (
                "mask",
                "key_value_states",
                "position_bias",
                "past_key_value",
                "layer_head_mask",
                "query_length",
                "use_cache",
                "output_attentions",
            )
        elif isinstance(module, (LlamaAttention, MistralAttention, MixtralAttention, Qwen2Attention, Gemma2Attention)):
            with_rope = True
            num_query_heads = module.config.num_attention_heads
            num_key_value_heads = module.config.num_key_value_heads
            q_proj, k_proj, v_proj, o_proj = module.q_proj, module.k_proj, module.v_proj, module.o_proj
            q_proj_rname, k_proj_rname, v_proj_rname, o_proj_rname = "q_proj", "k_proj", "v_proj", "o_proj"
            if hasattr(module, "q_rotary_emb"):
                # The module has been patched with explicit rotary-embedding
                # submodules (see `.patch.RotaryEmbedding`); hook q/k there —
                # presumably so post-RoPE activations are observed (TODO confirm).
                q, k = module.q_rotary_emb, module.k_rotary_emb
                q_rname, k_rname = "q_rotary_emb", "k_rotary_emb"
                assert isinstance(q, RotaryEmbedding)
                assert isinstance(k, RotaryEmbedding)
            else:
                q, k = module.q_proj, module.k_proj
                q_rname, k_rname = "q_proj", "k_proj"
            v, v_rname = module.v_proj, "v_proj"
            # Keyword arguments accepted by the decoder attention forward().
            kwargs = (
                "attention_mask",
                "position_ids",
                "past_key_value",
                "output_attentions",
                "use_cache",
                "position_embeddings",
                "cache_position",
            )
        else:
            raise ValueError(f"Unsupported attention type: {type(module)}")
        # Derive channel counts from the query projection weight, shaped (out, in).
        config = AttentionConfigStruct(
            hidden_size=q_proj.weight.shape[1],
            inner_size=q_proj.weight.shape[0],
            num_query_heads=num_query_heads,
            num_key_value_heads=num_key_value_heads,
            with_qk_norm=False,
            with_rope=with_rope,
        )
        if parent is not None and parent.config is not None:
            # Sanity-check the derived config against the parent's configuration.
            assert parent.config.hidden_size == config.hidden_size
            assert parent.config.inner_size == config.inner_size
            assert parent.config.num_query_heads == config.num_query_heads
            assert parent.config.num_key_value_heads == config.num_key_value_heads
            assert parent.config.with_qk_norm == config.with_qk_norm
            assert parent.config.with_rope == config.with_rope
        return LlmSelfAttentionStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            o_proj=o_proj,
            q=q,
            k=k,
            v=v,
            q_proj_rname=q_proj_rname,
            k_proj_rname=k_proj_rname,
            v_proj_rname=v_proj_rname,
            o_proj_rname=o_proj_rname,
            q_rname=q_rname,
            k_rname=k_rname,
            v_rname=v_rname,
            kwargs=kwargs,
        )
@dataclass(kw_only=True)
class LlmFeedForwardStruct(FeedForwardStruct):
    """Large Language Model Feedforward Network."""

    # Enclosing transformer-block struct (`None` when constructed standalone).
    parent: tp.Optional["LlmTransformerBlockStruct"] = field(repr=False)

    @staticmethod
    def _act_type_name(act_fn: object) -> str:
        """Normalized lowercase activation name derived from `act_fn`'s class name.

        Lowercases *before* stripping the "activation" suffix so that
        transformers classes such as `SiLUActivation` / `GELUActivation`
        reduce to `"silu"` / `"gelu"`. (Stripping before lowercasing never
        matched the CamelCase "Activation" suffix, producing strings like
        `"geluactivation"` that disagree with the config-derived act type.)
        Plain `nn.SiLU`-style names are unaffected.
        """
        return str(act_fn.__class__.__name__).lower().removesuffix("activation")

    @staticmethod
    def _default_construct(
        module: FEEDFORWARD_CLS,
        /,
        parent: tp.Optional["LlmTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "LlmFeedForwardStruct":
        """Construct an `LlmFeedForwardStruct` from a supported feedforward module.

        Args:
            module (`FEEDFORWARD_CLS`): A supported HuggingFace feedforward module.
            parent (`LlmTransformerBlockStruct` or `None`): Enclosing block struct.
            fname (`str`): Field name of this struct in its parent.
            rname (`str`): Module name of `module` relative to its parent.
            rkey (`str`): Relative key of this struct.
            idx (`int`): Index of this struct among its siblings.

        Returns:
            `LlmFeedForwardStruct`: The constructed struct.

        Raises:
            ValueError: If `module` is not a supported feedforward type.
        """
        if isinstance(module, (LlamaMLP, MistralMLP, Qwen2MLP, Gemma2MLP)):
            # Gated (GLU-style) MLP: up/gate projections feed one down projection.
            if parent is not None:
                assert parent.config.intermediate_act_type.endswith("_glu")
                act_type = parent.config.intermediate_act_type
            else:
                act_type = LlmFeedForwardStruct._act_type_name(module.act_fn) + "_glu"
            up_projs, down_projs = [module.up_proj, module.gate_proj], [module.down_proj]
            experts = [module]
            moe_gate = None
            up_proj_rnames = ["up_proj", "gate_proj"]
            down_proj_rnames = ["down_proj"]
            experts_rname = ""
            moe_gate_rname = ""
        elif isinstance(module, MixtralSparseMoeBlock):
            # Sparse MoE: gather per-expert projections (w3 ~ up, w1 ~ gate, w2 ~ down).
            if parent is not None:
                assert parent.config.intermediate_act_type.endswith("_glu")
                act_type = parent.config.intermediate_act_type
            else:
                act_type = LlmFeedForwardStruct._act_type_name(module.experts[0].act_fn) + "_glu"
            up_projs = [expert.w3 for expert in module.experts] + [expert.w1 for expert in module.experts]
            down_projs = [expert.w2 for expert in module.experts]
            experts = list(module.experts)
            moe_gate = module.gate
            up_proj_rnames = ["w3", "w1"]
            down_proj_rnames = ["w2"]
            experts_rname = "experts"
            moe_gate_rname = "gate"
        elif isinstance(module, T5DenseActDense):
            # Plain (non-gated) T5 feedforward.
            if parent is not None:
                assert not parent.config.intermediate_act_type.endswith("_glu")
                act_type = parent.config.intermediate_act_type
            else:
                act_type = LlmFeedForwardStruct._act_type_name(module.act)
            up_projs = [module.wi]
            down_projs = [module.wo]
            experts = [module]
            moe_gate = None
            up_proj_rnames = ["wi"]
            down_proj_rnames = ["wo"]
            experts_rname = ""
            moe_gate_rname = ""
        elif isinstance(module, T5DenseGatedActDense):
            # Gated T5 feedforward (wi_1 ~ up, wi_0 ~ gate).
            if parent is not None:
                assert parent.config.intermediate_act_type.endswith("_glu")
                act_type = parent.config.intermediate_act_type
            else:
                act_type = LlmFeedForwardStruct._act_type_name(module.act) + "_glu"
            up_projs = [module.wi_1, module.wi_0]
            down_projs = [module.wo]
            experts = [module]
            moe_gate = None
            up_proj_rnames = ["wi_1", "wi_0"]
            down_proj_rnames = ["wo"]
            experts_rname = ""
            moe_gate_rname = ""
        else:
            raise ValueError(f"Unsupported feed forward network type: {type(module)}")
        # Derive channel counts from the first up-projection weight, shaped (out, in).
        config = FeedForwardConfigStruct(
            hidden_size=up_projs[0].weight.shape[1],
            intermediate_size=up_projs[0].weight.shape[0],
            intermediate_act_type=act_type,
            num_experts=len(experts),
        )
        if parent is not None and parent.config is not None:
            # Sanity-check the derived config against the parent's configuration.
            assert parent.config.hidden_size == config.hidden_size
            assert parent.config.intermediate_size == config.intermediate_size
            assert parent.config.intermediate_act_type == config.intermediate_act_type
            assert parent.config.num_experts == config.num_experts
        return LlmFeedForwardStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            up_projs=up_projs,
            down_projs=down_projs,
            moe_gate=moe_gate,
            experts=experts,
            up_proj_rnames=up_proj_rnames,
            down_proj_rnames=down_proj_rnames,
            moe_gate_rname=moe_gate_rname,
            experts_rname=experts_rname,
        )
@dataclass(kw_only=True)
class LlmTransformerBlockStruct(TransformerBlockStruct):
    """Large Language Model Transformer Block."""

    # region relative keys
    attn_rkey: tp.ClassVar[str] = ""
    ffn_rkey: tp.ClassVar[str] = ""
    add_ffn_rkey: tp.ClassVar[str] = "add"
    attn_struct_cls: tp.ClassVar[tp.Type[LlmSelfAttentionStruct]] = LlmSelfAttentionStruct
    ffn_struct_cls: tp.ClassVar[tp.Type[LlmFeedForwardStruct]] = LlmFeedForwardStruct
    # endregion

    parent: tp.Optional["LlmTransformerStruct"] = field(repr=False)
    # LLM blocks are sequential (attention then FFN), never parallel branches.
    parallel: bool = field(init=False, repr=False, default=False)
    # NOTE(review): annotated non-Optional but defaulted to None; it is filled
    # in __post_init__ from the child structs when not supplied.
    config: LlmTransformerBlockConfigStruct = field(default=None)

    # region child modules
    # LLM blocks have no additive/parallel "add" FFN branch, so the slots
    # below are fixed to empty lists / None.
    pre_attn_add_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    pre_add_ffn_norm: None = field(init=False, repr=False, default=None)
    add_ffn: None = field(init=False, repr=False, default=None)
    post_add_ffn_norm: None = field(init=False, repr=False, default=None)
    # endregion

    # region relative names
    pre_attn_add_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    pre_add_ffn_norm_rname: str = field(init=False, repr=False, default="")
    add_ffn_rname: str = field(init=False, repr=False, default="")
    post_add_ffn_norm_rname: str = field(init=False, repr=False, default="")
    # endregion

    # region child structs
    attn_structs: list[LlmSelfAttentionStruct] = field(init=False, repr=False)
    ffn_struct: LlmFeedForwardStruct = field(init=False, repr=False)
    add_ffn_struct: None = field(init=False, repr=False, default=None)
    # endregion

    # region aliases
    # Singular accessors: an LLM block has exactly one attention sub-block
    # (asserted in __post_init__), so these expose index 0 of the plural fields.
    @property
    def pre_attn_norm(self) -> nn.LayerNorm | None:
        """The pre-attention norm, or `None` if absent."""
        return self.pre_attn_norms[0] if self.pre_attn_norms else None

    @property
    def attn(self) -> nn.Module:
        """The attention module."""
        return self.attns[0]

    @property
    def post_attn_norm(self) -> nn.LayerNorm | None:
        """The post-attention norm, or `None` if absent."""
        return self.post_attn_norms[0] if self.post_attn_norms else None

    @property
    def pre_attn_norm_rname(self) -> str:
        """Relative name of the pre-attention norm (empty if absent)."""
        return self.pre_attn_norm_rnames[0] if self.pre_attn_norm_rnames else ""

    @property
    def attn_rname(self) -> str:
        """Relative name of the attention module."""
        return self.attn_rnames[0]

    @property
    def post_attn_norm_rname(self) -> str:
        """Relative name of the post-attention norm (empty if absent)."""
        return self.post_attn_norm_rnames[0] if self.post_attn_norm_rnames else ""

    @property
    def pre_attn_norm_name(self) -> str:
        """Absolute name of the pre-attention norm (empty if absent)."""
        return self.pre_attn_norm_names[0] if self.pre_attn_norm_names else ""

    @property
    def attn_name(self) -> str:
        """Absolute name of the attention module."""
        return self.attn_names[0]

    @property
    def post_attn_norm_name(self) -> str:
        """Absolute name of the post-attention norm (empty if absent)."""
        return self.post_attn_norm_names[0] if self.post_attn_norm_names else ""

    @property
    def attn_struct(self) -> LlmSelfAttentionStruct:
        """The attention struct."""
        return self.attn_structs[0]
    # endregion

    def __post_init__(self):
        """Validate the single-attention invariant and derive config if missing."""
        super().__post_init__()
        assert len(self.attn_structs) == 1
        if self.config is None:
            # Merge the attention and feedforward configs into the block config.
            self.config = LlmTransformerBlockConfigStruct(
                hidden_size=self.attn_struct.config.hidden_size,
                inner_size=self.attn_struct.config.inner_size,
                num_query_heads=self.attn_struct.config.num_query_heads,
                num_key_value_heads=self.attn_struct.config.num_key_value_heads,
                with_qk_norm=self.attn_struct.config.with_qk_norm,
                with_rope=self.attn_struct.config.with_rope,
                intermediate_size=self.ffn_struct.config.intermediate_size,
                intermediate_act_type=self.ffn_struct.config.intermediate_act_type,
                num_experts=self.ffn_struct.config.num_experts,
            )

    @staticmethod
    def _default_construct(
        module: TRANSFORMER_BLOCK_CLS,
        /,
        parent: tp.Optional["LlmTransformerStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "LlmTransformerBlockStruct":
        """Construct an `LlmTransformerBlockStruct` from a supported layer module.

        Args:
            module (`TRANSFORMER_BLOCK_CLS`): A supported HuggingFace layer module.
            parent (`LlmTransformerStruct` or `None`): Enclosing backbone struct.
            fname (`str`): Field name of this struct in its parent.
            rname (`str`): Module name of `module` relative to its parent.
            rkey (`str`): Relative key of this struct.
            idx (`int`): Index of this struct among its siblings.

        Returns:
            `LlmTransformerBlockStruct`: The constructed struct.

        Raises:
            ValueError: If `module` is not a supported layer type.
        """
        if isinstance(
            module, (LlamaDecoderLayer, MistralDecoderLayer, Qwen2DecoderLayer, MixtralDecoderLayer, Gemma2DecoderLayer)
        ):
            pre_attn_norms, attns = [module.input_layernorm], [module.self_attn]
            pre_attn_norm_rnames, attn_rnames = ["input_layernorm"], ["self_attn"]
            if isinstance(module, Gemma2DecoderLayer):
                # Gemma2 carries four norms: pre/post-attention and pre/post-FFN.
                post_attn_norms, post_attn_norm_rnames = [module.post_attention_layernorm], ["post_attention_layernorm"]
                pre_ffn_norm, pre_ffn_norm_rname = (module.pre_feedforward_layernorm, "pre_feedforward_layernorm")
                post_ffn_norm, post_ffn_norm_rname = module.post_feedforward_layernorm, "post_feedforward_layernorm"
            else:
                # Other decoders: "post_attention_layernorm" actually precedes the FFN.
                post_attn_norms, post_attn_norm_rnames = [], []
                pre_ffn_norm, pre_ffn_norm_rname = module.post_attention_layernorm, "post_attention_layernorm"
                post_ffn_norm, post_ffn_norm_rname = None, ""
            if isinstance(module, MixtralDecoderLayer):
                ffn, ffn_rname = module.block_sparse_moe, "block_sparse_moe"
            else:
                ffn, ffn_rname = module.mlp, "mlp"
        elif isinstance(module, T5Block):
            # T5 encoder blocks interleave self-attention and FFN sub-layers in
            # `module.layer`; any other sub-layer type (e.g. cross-attention in
            # decoder blocks) trips the assert below.
            pre_attn_norms, attns, pre_attn_norm_rnames, attn_rnames = [], [], [], []
            post_attn_norms, post_attn_norm_rnames = [], []
            post_ffn_norm, post_ffn_norm_rname = None, ""
            for i, layer in enumerate(module.layer):
                if isinstance(layer, T5LayerSelfAttention):
                    pre_attn_norms.append(layer.layer_norm)
                    attns.append(layer.SelfAttention)
                    pre_attn_norm_rnames.append(f"layer.{i}.layer_norm")
                    attn_rnames.append(f"layer.{i}.SelfAttention")
                else:
                    assert isinstance(layer, T5LayerFF)
                    pre_ffn_norm, ffn = layer.layer_norm, layer.DenseReluDense
                    pre_ffn_norm_rname, ffn_rname = f"layer.{i}.layer_norm", f"layer.{i}.DenseReluDense"
        else:
            raise ValueError(f"Unsupported layer type: {type(module)}")
        # Inherit the parent's config when available; derived in __post_init__ otherwise.
        config = parent.config if parent is not None and parent.config is not None else None
        return LlmTransformerBlockStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            pre_attn_norms=pre_attn_norms,
            attns=attns,
            post_attn_norms=post_attn_norms,
            pre_ffn_norm=pre_ffn_norm,
            ffn=ffn,
            post_ffn_norm=post_ffn_norm,
            pre_attn_norm_rnames=pre_attn_norm_rnames,
            attn_rnames=attn_rnames,
            post_attn_norm_rnames=post_attn_norm_rnames,
            pre_ffn_norm_rname=pre_ffn_norm_rname,
            ffn_rname=ffn_rname,
            post_ffn_norm_rname=post_ffn_norm_rname,
        )
@dataclass(kw_only=True)
class LlmTransformerStruct(BaseTransformerStruct):
    """Large Language Model Structure."""

    # region relative keys
    layer_rkey: tp.ClassVar[str] = ""
    layer_struct_cls: tp.ClassVar[tp.Type[LlmTransformerBlockStruct]] = LlmTransformerBlockStruct
    # endregion

    parent: tp.Optional["LlmModelStruct"] = field(repr=False)
    # NOTE(review): annotated non-Optional but defaulted to None; derived from
    # the first block's config in __post_init__ when not supplied.
    config: LlmTransformerConfigStruct = field(default=None)

    # region child modules
    # embeddings: list[nn.Embedding]
    # """list of embeddings [embed_tokens, embed_positions]"""
    embed_tokens: nn.Embedding
    """Token embedding module."""
    embed_positions: nn.Embedding | None
    """Position embedding module."""
    # Stack of transformer block modules.
    layers: nn.ModuleList
    # endregion

    # region relative names
    embed_tokens_rname: str
    embed_positions_rname: str
    layers_rname: str
    # endregion

    # region absolute names
    embed_tokens_name: str = field(init=False, repr=False)
    embed_positions_name: str = field(init=False, repr=False)
    layers_name: str = field(init=False, repr=False)
    layer_names: list[str] = field(init=False, repr=False)
    # endregion

    # region child structs
    layer_structs: list[LlmTransformerBlockStruct] = field(init=False, repr=False)
    # endregion

    # region abstractmethod implementations
    @property
    def num_blocks(self) -> int:
        """Get the number of transformer blocks."""
        return len(self.layers)

    @property
    def block_structs(self) -> list[LlmTransformerBlockStruct]:
        """Get the list of transformer block structs."""
        return self.layer_structs

    @property
    def block_names(self) -> list[str]:
        """Get the list of transformer block names."""
        return self.layer_names
    # endregion

    def __post_init__(self) -> None:
        """Resolve absolute names, build per-layer structs, and fill/verify config."""
        super().__post_init__()
        self.embed_tokens_name = join_name(self.name, self.embed_tokens_rname)
        if self.embed_positions is not None:
            self.embed_positions_name = join_name(self.name, self.embed_positions_rname)
        else:
            self.embed_positions_name = ""
        self.layers_name = join_name(self.name, self.layers_rname)
        layer_rnames = [f"{self.layers_rname}.{idx}" for idx in range(len(self.layers))]
        self.layer_names = [join_name(self.name, rname) for rname in layer_rnames]
        # Recursively wrap every transformer block in a block struct.
        self.layer_structs = [
            self.layer_struct_cls.construct(
                layer, parent=self, fname="layer", rname=rname, rkey=self.layer_rkey, idx=idx
            )
            for idx, (layer, rname) in enumerate(zip(self.layers, layer_rnames, strict=True))
        ]
        if self.config is None:
            # All blocks must share a single configuration; lift it to this level.
            assert all(block.config == self.block_structs[0].config for block in self.block_structs)
            ref_config = self.block_structs[0].config
            self.config = LlmTransformerConfigStruct(
                hidden_size=ref_config.hidden_size,
                inner_size=ref_config.inner_size,
                num_query_heads=ref_config.num_query_heads,
                num_key_value_heads=ref_config.num_key_value_heads,
                with_qk_norm=ref_config.with_qk_norm,
                with_rope=ref_config.with_rope,
                intermediate_size=ref_config.intermediate_size,
                intermediate_act_type=ref_config.intermediate_act_type,
                num_experts=ref_config.num_experts,
                vocab_size=self.embed_tokens.num_embeddings,
                num_hidden_layers=self.num_blocks,
            )
        else:
            assert self.config.vocab_size == self.embed_tokens.num_embeddings
            assert self.config.num_hidden_layers == self.num_blocks

    def get_iter_layer_activations_args(
        self, **kwargs
    ) -> tuple[list[nn.Module], list[LlmTransformerBlockStruct], list[bool], list[bool]]:
        """
        Get the arguments for iterating over the layers and their activations.

        Args:
            skip_pre_modules (`bool`):
                Whether to skip the pre-modules
            skip_post_modules (`bool`):
                Whether to skip the post-modules

        Returns:
            `tuple[list[nn.Module], list[LlmTransformerBlockStruct], list[bool], list[bool]]`:
                the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
        """
        # No layer is recomputed; every layer consumes the previous layer's outputs.
        return self.layers, self.layer_structs, [False] * len(self.layers), [True] * len(self.layers)

    @staticmethod
    def _default_construct(
        module: TRANSFORMER_CLS,
        /,
        parent: tp.Optional["LlmModelStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "LlmTransformerStruct":
        """Construct an `LlmTransformerStruct` from a supported backbone module.

        Args:
            module (`TRANSFORMER_CLS`): A supported HuggingFace backbone module.
            parent (`LlmModelStruct` or `None`): Enclosing model struct.
            fname (`str`): Field name of this struct in its parent.
            rname (`str`): Module name of `module` relative to its parent.
            rkey (`str`): Relative key of this struct.
            idx (`int`): Index of this struct among its siblings.

        Returns:
            `LlmTransformerStruct`: The constructed struct.

        Raises:
            ValueError: If `module` is not a supported backbone type.
        """
        if isinstance(module, (LlamaModel, MistralModel, MixtralModel, Qwen2Model, Gemma2Model)):
            # Decoder-only backbones: token embeddings + layer stack + final norm.
            embed_tokens, embed_positions = module.embed_tokens, None
            layers = module.layers
            norm_in, norm_out = None, module.norm
            proj_in, proj_out = None, None
            embed_tokens_rname, embed_positions_rname = "embed_tokens", ""
            layers_rname = "layers"
            norm_in_rname, norm_out_rname = "", "norm"
            proj_in_rname, proj_out_rname = "", ""
        elif isinstance(module, T5Stack):
            # T5 names its layer stack "block" and final norm "final_layer_norm".
            embed_tokens, embed_positions = module.embed_tokens, None
            layers = module.block
            norm_in, norm_out = None, module.final_layer_norm
            proj_in, proj_out = None, None
            embed_tokens_rname, embed_positions_rname = "embed_tokens", ""
            layers_rname = "block"
            norm_in_rname, norm_out_rname = "", "final_layer_norm"
            proj_in_rname, proj_out_rname = "", ""
        else:
            raise ValueError(f"Unsupported backbone type: {type(module)}")
        # Inherit the parent's config when available; derived in __post_init__ otherwise.
        config = parent.config if parent is not None and parent.config is not None else None
        return LlmTransformerStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            embed_tokens=embed_tokens,
            embed_positions=embed_positions,
            norm_in=norm_in,
            proj_in=proj_in,
            layers=layers,
            norm_out=norm_out,
            proj_out=proj_out,
            embed_tokens_rname=embed_tokens_rname,
            embed_positions_rname=embed_positions_rname,
            norm_in_rname=norm_in_rname,
            proj_in_rname=proj_in_rname,
            layers_rname=layers_rname,
            norm_out_rname=norm_out_rname,
            proj_out_rname=proj_out_rname,
        )
@dataclass(kw_only=True)
class LlmModelStruct(BaseModuleStruct):
    """Large Language Model Structure."""

    # region relative keys
    backbone_rkey: tp.ClassVar[str] = ""
    head_rkey: tp.ClassVar[str] = "head"
    backbone_struct_cls: tp.ClassVar[tp.Type[LlmTransformerStruct]] = LlmTransformerStruct
    # endregion

    # The wrapped HuggingFace model (positional so it can be passed unnamed).
    module: PreTrainedModel = field(repr=False, kw_only=False)
    config: LlmConfigStruct

    # region child modules
    # Transformer backbone (decoder stack or T5 encoder stack).
    backbone: nn.Module
    # Output head (lm_head / score), or None for headless models.
    head: nn.Linear | None
    # endregion

    # region relative names
    backbone_rname: str
    head_rname: str
    # endregion

    # region absolute names
    backbone_name: str = field(init=False, repr=False)
    head_name: str = field(init=False, repr=False)
    # endregion

    # region absolute keys
    head_key: str = field(init=False, repr=False)
    # endregion

    # region child structs
    backbone_struct: LlmTransformerStruct = field(init=False, repr=False)
    # endregion

    def __post_init__(self) -> None:
        """Resolve names/keys and recursively build the backbone struct."""
        super().__post_init__()
        self.backbone_name = join_name(self.name, self.backbone_rname)
        # NOTE(review): head_name is also resolved when head is None but a
        # relative name was supplied; only a fully absent head clears both.
        if self.head is not None or self.head_rname:
            self.head_name = join_name(self.name, self.head_rname)
        else:
            self.head_name = self.head_rname = ""
        self.head_key = join_name(self.key, self.head_rkey, sep="_")
        self.backbone_struct = self.backbone_struct_cls.construct(
            self.backbone, parent=self, fname="backbone", rname=self.backbone_rname, rkey=self.backbone_rkey
        )

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        """Yield (key, name, module, owner struct, field name) for every key module."""
        yield from self.backbone_struct.named_key_modules()
        if self.head is not None:
            yield self.head_key, self.head_name, self.head, self, "head"

    def iter_attention_structs(self) -> tp.Generator[LlmSelfAttentionStruct, None, None]:
        """Iterate over all attention structs in the backbone."""
        yield from self.backbone_struct.iter_attention_structs()

    def iter_transformer_block_structs(self) -> tp.Generator[LlmTransformerBlockStruct, None, None]:
        """Iterate over all transformer-block structs in the backbone."""
        yield from self.backbone_struct.iter_transformer_block_structs()

    def get_iter_layer_activations_args(
        self, **kwargs
    ) -> tuple[list[nn.Module], list[LlmTransformerBlockStruct], list[bool], list[bool]]:
        """
        Get the arguments for iterating over the layers and their activations.

        Args:
            skip_pre_modules (`bool`):
                Whether to skip the pre-modules
            skip_post_modules (`bool`):
                Whether to skip the post-modules

        Returns:
            `tuple[list[nn.Module], list[LlmTransformerBlockStruct], list[bool], list[bool]]`:
                the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
        """
        return self.backbone_struct.get_iter_layer_activations_args(**kwargs)

    @staticmethod
    def _default_construct(
        model: nn.Module,
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "LlmModelStruct":
        """Build the Large Language Model Structure.

        Args:
            model (`nn.Module`): A supported HuggingFace model (with or without head).
            parent (`BaseModuleStruct` or `None`): Parent struct, if any.
            fname (`str`): Field name of this struct in its parent.
            rname (`str`): Module name of `model` relative to its parent.
            rkey (`str`): Relative key of this struct.
            idx (`int`): Index of this struct among its siblings.

        Returns:
            `LlmModelStruct`: The constructed model struct.

        Raises:
            ValueError: If the model or its config type is unsupported.
        """
        # Locate the backbone module.
        if isinstance(model, CASUALLM_CLS) or isinstance(model, SEQCLSLM_CLS):
            backbone = model.model
            backbone_rname = "model"
        elif isinstance(model, T5EncoderModel):
            backbone = model.encoder
            backbone_rname = "encoder"
        elif isinstance(model, TRANSFORMER_CLS):
            # A bare backbone was passed in directly.
            backbone = model
            backbone_rname = ""
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")
        # Locate the output head, if any.
        if isinstance(model, CASUALLM_CLS):
            head = model.lm_head
            head_rname = "lm_head"
        elif isinstance(model, SEQCLSLM_CLS):
            head = model.score
            head_rname = "score"
        elif isinstance(model, T5EncoderModel):
            head = None
            head_rname = ""
        elif isinstance(model, TRANSFORMER_CLS):
            head = None
            head_rname = ""
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")
        # Translate the HuggingFace config into an LlmConfigStruct.
        config = backbone.config
        if isinstance(config, (LlamaConfig, MistralConfig, MixtralConfig, Qwen2Config, Gemma2Config)):
            config_struct = LlmConfigStruct(
                hidden_size=config.hidden_size,
                # Gemma2's query width (heads * head_dim) differs from hidden_size.
                inner_size=config.num_attention_heads * config.head_dim
                if isinstance(config, Gemma2Config)
                else config.hidden_size,
                num_query_heads=config.num_attention_heads,
                num_key_value_heads=config.num_key_value_heads,
                with_qk_norm=False,
                with_rope=True,
                intermediate_size=config.intermediate_size,
                intermediate_act_type=f"{config.hidden_act}_glu".lower(),
                # Only Mixtral configs carry num_local_experts; default to 1 expert.
                num_experts=getattr(config, "num_local_experts", 1),
                vocab_size=config.vocab_size,
                num_hidden_layers=config.num_hidden_layers,
                tie_word_embeddings=config.tie_word_embeddings,
            )
        elif isinstance(config, T5Config):
            config_struct = LlmConfigStruct(
                hidden_size=config.d_model,
                inner_size=config.d_kv * config.num_heads,
                num_query_heads=config.num_heads,
                num_key_value_heads=config.num_heads,
                with_rope=False,
                intermediate_size=config.d_ff,
                intermediate_act_type=config.dense_act_fn.lower(),
                num_experts=1,
                vocab_size=config.vocab_size,
                num_hidden_layers=config.num_layers,
                tie_word_embeddings=False,
            )
            if config.is_gated_act:
                # Gated T5 variants use GLU-style feedforward networks.
                config_struct.intermediate_act_type += "_glu"
        else:
            raise ValueError(f"Unsupported config type: {type(config)}")
        return LlmModelStruct(
            module=model,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config_struct,
            backbone=backbone,
            head=head,
            backbone_rname=backbone_rname,
            head_rname=head_rname,
        )
# Register the default constructors so that `construct()` dispatches on the
# supported HuggingFace module classes enumerated in the aliases above.
LlmSelfAttentionStruct.register_factory(ATTENTION_CLS, LlmSelfAttentionStruct._default_construct)
LlmFeedForwardStruct.register_factory(FEEDFORWARD_CLS, LlmFeedForwardStruct._default_construct)
LlmTransformerBlockStruct.register_factory(TRANSFORMER_BLOCK_CLS, LlmTransformerBlockStruct._default_construct)
LlmTransformerStruct.register_factory(TRANSFORMER_CLS, LlmTransformerStruct._default_construct)
LlmModelStruct.register_factory(
    tp.Union[TRANSFORMER_CLS, CASUALLM_CLS, SEQCLSLM_CLS, T5EncoderModel], LlmModelStruct._default_construct
)
================================================
FILE: deepcompressor/app/llm/ptq.py
================================================
# -*- coding: utf-8 -*-
"""Evaluate a large language model."""
import gc
import json
import os
import pprint
import traceback
import torch
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizer
from deepcompressor.utils import tools
from .config import LlmCacheConfig, LlmPtqRunConfig, LlmQuantCacheConfig, LlmQuantConfig
from .nn import LlmModelStruct
from .quant import quantize_llm_activations, quantize_llm_weights, reorder_llm, rotate_llm, smooth_llm
__all__ = ["ptq"]
def ptq(  # noqa: C901
    model: PreTrainedModel | LlmModelStruct,
    /,
    tokenizer: PreTrainedTokenizer,
    config: LlmQuantConfig,
    cache: LlmCacheConfig | None = None,
    load_dirpath: str = "",
    save_dirpath: str = "",
    copy_on_save: bool = False,
    save_model: bool = False,
) -> PreTrainedModel:
    """Post-training quantization of a large language model.

    The pipeline runs (in order): rotation, channel reordering, smoothing,
    weight quantization, and activation quantization.  Each stage can load its
    intermediate result from ``load_dirpath`` or the shared ``cache``, and can
    persist (or symlink) its result into ``save_dirpath``.

    Args:
        model (`PreTrainedModel` or `LlmStruct`):
            The large language model.
        tokenizer (`PreTrainedTokenizer`):
            The large language model tokenizer.
        config (`LlmQuantConfig`):
            The large language model post-training quantization configuration.
        cache (`LlmCacheConfig`, *optional*, defaults to `None`):
            The large language model quantization cache path configuration.
        load_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to load the quantization checkpoint.
        save_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to save the quantization checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the cache to the save directory.
        save_model (`bool`, *optional*, defaults to `False`):
            Whether to save the quantized model checkpoint.

    Returns:
        `PreTrainedModel`:
            The quantized model.
    """
    logger = tools.logging.getLogger(__name__)
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    quant_wgts = config.enabled_wgts
    quant_ipts = config.enabled_ipts
    quant_opts = config.enabled_opts
    quant_acts = quant_ipts or quant_opts
    quant = quant_wgts or quant_acts
    needs_rotation = quant and config.enabled_rotation
    needs_reorder = quant and config.enabled_reorder
    needs_smooth = quant and config.enabled_smooth
    # region resolve checkpoint load/save paths
    load_model_path, load_path, save_path = "", None, None
    if load_dirpath:
        load_path = LlmQuantCacheConfig(
            rotation=os.path.join(load_dirpath, "rotation.pt"),
            reorder=os.path.join(load_dirpath, "reorder.pt"),
            smooth=os.path.join(load_dirpath, "smooth.pt"),
            wgts=os.path.join(load_dirpath, "wgts.pt"),
            acts=os.path.join(load_dirpath, "acts.pt"),
        )
        load_model_path = os.path.join(load_dirpath, "model.pt")
        if os.path.exists(load_model_path):
            logger.info(f"* Found the model from {load_model_path}")
            load_model = True
            save_dirpath = ""  # do not save the model if loading
            # A fully quantized checkpoint already bakes in static reordering
            # and smoothing, so those stages can be skipped safely.
            if needs_reorder and not config.reorder.dynamic:
                needs_reorder = False
                logger.info("* Safe to skip reordering the model")
            if needs_smooth:
                needs_smooth = False
                logger.info("* Safe to skip smoothing the model")
        else:
            logger.warning(f"Model checkpoint {load_model_path} does not exist")
            load_model, load_model_path = False, ""
    else:
        load_model = False
    if save_dirpath:
        os.makedirs(save_dirpath, exist_ok=True)
        save_path = LlmQuantCacheConfig(
            rotation=os.path.join(save_dirpath, "rotation.pt"),
            reorder=os.path.join(save_dirpath, "reorder.pt"),
            smooth=os.path.join(save_dirpath, "smooth.pt"),
            wgts=os.path.join(save_dirpath, "wgts.pt"),
            acts=os.path.join(save_dirpath, "acts.pt"),
        )
    else:
        save_model = False
    # endregion
    # region rotate model
    if needs_rotation:
        logger.info("* Rotating model")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        # Prefer an explicit checkpoint, then the shared cache, then the
        # rotation path configured in the quantization config.
        if load_path and os.path.exists(load_path.rotation):
            load_from = load_path.rotation
        elif cache and cache.path.rotation and os.path.exists(cache.path.rotation):
            load_from = cache.path.rotation
        elif os.path.exists(config.rotation.path):
            load_from = config.rotation.path
        if load_from:
            logger.info(f"- Loading rotation from {load_from}")
            rotation = torch.load(load_from).to(dtype=torch.float64)
            rotate_llm(model, config.rotation, rotation=rotation)
        else:
            logger.info("- Generating rotation")
            rotation = rotate_llm(model, config.rotation)
            if cache and cache.path.rotation:
                logger.info(f"- Saving rotation to {cache.path.rotation}")
                os.makedirs(cache.dirpath.rotation, exist_ok=True)
                torch.save(rotation, cache.path.rotation)
                load_from = cache.path.rotation
        if save_path:
            # Symlink to the cached artifact unless an actual copy is requested.
            if not copy_on_save and load_from:
                logger.info(f"- Linking rotation to {save_path.rotation}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.rotation)
            else:
                logger.info(f"- Saving rotation to {save_path.rotation}")
                torch.save(rotation, save_path.rotation)
        del rotation
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
        logger.info(f"* Development dtype is {config.develop_dtype}")
    # endregion
    # region reorder channels
    if needs_reorder:
        logger.info("* Reordering channels")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.reorder):
            load_from = load_path.reorder
        elif cache and cache.path.reorder and os.path.exists(cache.path.reorder):
            load_from = cache.path.reorder
        if load_from:
            logger.info(f"- Loading reorder indices from {load_from}")
            reorder_cache = torch.load(load_from)
            reorder_llm(model, config, tokenizer, reorder_cache=reorder_cache)
        else:
            logger.info("- Generating reorder indices")
            reorder_cache = reorder_llm(model, config, tokenizer)
            if cache and cache.path.reorder:
                logger.info(f"- Saving reorder indices to {cache.path.reorder}")
                os.makedirs(cache.dirpath.reorder, exist_ok=True)
                torch.save(reorder_cache, cache.path.reorder)
                load_from = cache.path.reorder
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking reorder indices to {save_path.reorder}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.reorder)
            else:
                logger.info(f"- Saving reorder indices to {save_path.reorder}")
                torch.save(reorder_cache, save_path.reorder)
        del reorder_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region smooth quantization
    if needs_smooth:
        logger.info("* Smoothing model for quantization")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.smooth):
            load_from = load_path.smooth
        elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
            load_from = cache.path.smooth
        if load_from:
            logger.info(f"- Loading smooth scales from {load_from}")
            smooth_cache = torch.load(load_from)
            smooth_llm(model, config, smooth_cache=smooth_cache)
        else:
            logger.info("- Generating smooth scales")
            smooth_cache = smooth_llm(model, config, tokenizer=tokenizer)
            if cache and cache.path.smooth:
                logger.info(f"- Saving smooth scales to {cache.path.smooth}")
                os.makedirs(cache.dirpath.smooth, exist_ok=True)
                torch.save(smooth_cache, cache.path.smooth)
                load_from = cache.path.smooth
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking smooth scales to {save_path.smooth}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
            else:
                logger.info(f"- Saving smooth scales to {save_path.smooth}")
                torch.save(smooth_cache, save_path.smooth)
        del smooth_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region collect original state dict
    # The pristine (pre-weight-quantization) weights are only needed when the
    # activation quantizers must be calibrated from scratch.
    if config.needs_acts_quantizer_cache:
        if load_path and os.path.exists(load_path.acts):
            orig_state_dict = None
        elif cache and cache.path.acts and os.path.exists(cache.path.acts):
            orig_state_dict = None
        else:
            orig_state_dict: dict[str, torch.Tensor] = {
                name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
            }
    else:
        orig_state_dict = None
    # endregion
    if load_model:
        logger.info(f"* Loading model checkpoint from {load_model_path}")
        model.module.load_state_dict(torch.load(load_model_path))
        gc.collect()
        torch.cuda.empty_cache()
    elif quant_wgts:
        logger.info("* Quantizing weights")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.wgts):
            load_from = load_path.wgts
        elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
            load_from = cache.path.wgts
        if load_from:
            logger.info(f"- Loading weight quantizer settings from {load_from}")
            quantizer_state_dict = torch.load(load_from)
            _, scale_state_dict = quantize_llm_weights(
                model,
                config,
                tokenizer=tokenizer,
                quantizer_state_dict=quantizer_state_dict,
                return_with_scale_state_dict=save_model,
            )
        else:
            logger.info("- Generating weight quantizer settings")
            quantizer_state_dict, scale_state_dict = quantize_llm_weights(
                model, config, tokenizer=tokenizer, return_with_scale_state_dict=save_model
            )
            if cache and cache.dirpath.wgts:
                logger.info(f"- Saving weight quantizer settings to {cache.path.wgts}")
                os.makedirs(cache.dirpath.wgts, exist_ok=True)
                torch.save(quantizer_state_dict, cache.path.wgts)
                load_from = cache.path.wgts
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking weight quantizer settings to {save_path.wgts}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.wgts)
            else:
                logger.info(f"- Saving weight quantizer settings to {save_path.wgts}")
                torch.save(quantizer_state_dict, save_path.wgts)
        if save_model:
            logger.info(f"- Saving model checkpoint to {save_dirpath}")
            torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
            torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
        del quantizer_state_dict, scale_state_dict
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    if quant_acts:
        # NOTE: fixed the log message to match the "* ..." prefix used by
        # every other stage (it previously carried a stray leading space).
        logger.info("* Quantizing activations")
        tools.logging.Formatter.indent_inc()
        if config.needs_acts_quantizer_cache:
            load_from = ""
            if load_path and os.path.exists(load_path.acts):
                load_from = load_path.acts
            elif cache and cache.path.acts and os.path.exists(cache.path.acts):
                load_from = cache.path.acts
            if load_from:
                logger.info(f"- Loading activation quantizer settings from {load_from}")
                quantizer_state_dict = torch.load(load_from)
                quantize_llm_activations(
                    model,
                    config,
                    tokenizer=tokenizer,
                    quantizer_state_dict=quantizer_state_dict,
                    orig_state_dict=orig_state_dict,
                )
            else:
                logger.info("- Generating activation quantizer settings")
                quantizer_state_dict = quantize_llm_activations(
                    model, config, tokenizer=tokenizer, orig_state_dict=orig_state_dict
                )
                if cache and cache.dirpath.acts:
                    logger.info(f"- Saving activation quantizer settings to {cache.path.acts}")
                    os.makedirs(cache.dirpath.acts, exist_ok=True)
                    torch.save(quantizer_state_dict, cache.path.acts)
                    load_from = cache.path.acts
            # Consistent with the other stages, guard on `save_path` (it is
            # non-None exactly when `save_dirpath` is set at this point).
            if save_path:
                if not copy_on_save and load_from:
                    logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
                    os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
                else:
                    logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
                    torch.save(quantizer_state_dict, save_path.acts)
            del quantizer_state_dict
        else:
            logger.info("- No need to generate/load activation quantizer settings")
            quantize_llm_activations(model, config, tokenizer=tokenizer, orig_state_dict=orig_state_dict)
        tools.logging.Formatter.indent_dec()
        del orig_state_dict
        gc.collect()
        torch.cuda.empty_cache()
    return model.module
def main(config: LlmPtqRunConfig, logging_level: int = tools.logging.DEBUG) -> None:  # noqa: C901
    """Post-training quantization and evaluation of a large language model.

    Args:
        config (`LlmPtqConfig`):
            The large language model post-training quantization configuration.
        logging_level (`int`, *optional*, defaults to `logging.DEBUG`):
            The logging level.
    """
    # Lock the output directory so concurrent runs do not clobber each other.
    config.output.lock()
    config.dump(path=config.output.get_running_job_path("config.yaml"))
    tools.logging.setup(path=config.output.get_running_job_path("run.log"), level=logging_level)
    logger = tools.logging.getLogger(__name__)
    # region log configurations
    logger.info("=== Configurations ===")
    tools.logging.info(config.formatted_str(), logger=logger)
    logger.info("=== Dumped Configurations ===")
    tools.logging.info(pprint.pformat(config.dump(), indent=2, width=120), logger=logger)
    logger.info("=== Output Directory ===")
    logger.info(config.output.job_dirpath)
    # endregion
    logger.info("=== Start Evaluating ===")
    logger.info(f"* Building model {config.model.name} from {config.model.path}")
    tools.logging.Formatter.indent_inc()
    model, tokenizer = config.model.build()
    tools.logging.Formatter.indent_dec()
    # region resolve model saving destination
    save_dirpath = os.path.join(config.output.running_job_dirpath, "cache")
    if config.save_model:
        # `save_model` is a string flag: "false"-like disables saving,
        # "true"/"default" saves under the job directory, anything else is
        # treated as an explicit destination path.
        if config.save_model.lower() in ("false", "none", "null", "nil"):
            save_model = False
        elif config.save_model.lower() in ("true", "default"):
            save_dirpath, save_model = os.path.join(config.output.running_job_dirpath, "model"), True
        else:
            save_dirpath, save_model = config.save_model, True
    else:
        save_model = False
    # endregion
    model = ptq(
        model,
        tokenizer=tokenizer,
        config=config.quant,
        cache=config.cache,
        load_dirpath=config.load_from,
        save_dirpath=save_dirpath,
        copy_on_save=config.copy_on_save,
        save_model=save_model,
    )
    # region evaluate model
    if not config.skip_eval:
        logger.info("* Evaluating model")
        eos_token_ids = GenerationConfig.from_pretrained(config.model.path).eos_token_id
        # Normalize to a list: `eos_token_id` may be a single id or a list.
        if not isinstance(eos_token_ids, list):
            eos_token_ids = [eos_token_ids]
        tools.logging.Formatter.indent_inc()
        results = config.eval.evaluate(
            model,
            tokenizer,
            model_name=config.model.name,
            eos_token_ids=eos_token_ids,
            output_dirpath=config.output.get_running_job_path("eval"),
        )
        tools.logging.Formatter.indent_dec()
        logger.info(f"* Saving results to {config.output.job_dirpath}")
        # dump results (the redundant single-argument `os.path.join` wrapper
        # around the path was removed)
        with open(config.output.get_running_job_path("results.json"), "w") as f:
            json.dump(results, f, indent=2)
    # endregion
    config.output.unlock()
if __name__ == "__main__":
    # Parse CLI arguments; warn on unused configs/args, reject unknown args.
    config, _, unused_cfgs, unused_args, unknown_args = LlmPtqRunConfig.get_parser().parse_known_args()
    if len(unused_cfgs) > 0:
        tools.logging.warning(f"Unused configurations: {unused_cfgs}")
    if unused_args is not None:
        tools.logging.warning(f"Unused arguments: {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"
    try:
        main(config, logging_level=tools.logging.DEBUG)
    except Exception:
        tools.logging.Formatter.indent_reset()
        tools.logging.error("=== Error ===")
        tools.logging.error(traceback.format_exc())
        tools.logging.shutdown()
        traceback.print_exc()
        # Release the output-directory lock before propagating the failure.
        config.output.unlock(error=True)
        # Bare `raise` re-raises the active exception with its original
        # traceback (no extra re-raise frame, unlike `raise e`).
        raise
================================================
FILE: deepcompressor/app/llm/quant/__init__.py
================================================
# -*- coding: utf-8 -*-
from .activation import quantize_llm_activations
from .config import LlmQuantCacheConfig, LlmQuantConfig
from .quantizer import LlmActivationQuantizer, LlmWeightQuantizer
from .reorder import reorder_llm
from .rotate import rotate_llm
from .smooth import smooth_llm
from .weight import quantize_llm_weights
================================================
FILE: deepcompressor/app/llm/quant/activation.py
================================================
# -*- coding: utf-8 -*-
"""LLM activation quantization calibration module."""
import gc
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.data.common import TensorType
from deepcompressor.utils import tools
from ..nn import LlmModelStruct, LlmTransformerBlockStruct
from .config import LlmQuantConfig
from .quantizer import LlmActivationQuantizer
from .utils import get_needs_inputs_fn, get_needs_outputs_fn
__all__ = ["quantize_llm_activations"]
@torch.inference_mode()
def quantize_llm_layer_activations(  # noqa: C901
    layer: LlmTransformerBlockStruct,
    config: LlmQuantConfig,
    quantizer_state_dict: dict[str, tp.Any],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
    orig_state_dict: dict[str, torch.Tensor] | None = None,
) -> None:
    """Calibrate the activation quantization ranges of modules in a layer.

    For each quantized tensor site in the layer (qkv-projection inputs,
    q/k/v attention outputs, out-projection inputs, and FFN projection
    inputs per expert), a quantizer is built, calibrated (or restored from
    `quantizer_state_dict`), and installed as a forward hook on the module.
    `quantizer_state_dict` is updated in place with newly calibrated entries.

    Args:
        layer (`LlmTransformerBlockStruct`):
            Layer.
        config (`LlmQuantConfig`):
            Quantization configuration.
        quantizer_state_dict (`dict[str, Any]`):
            Activation quantizer state dict.
        layer_cache (`dict[str, IOTensorsCache]` or `None`, *optional*, defaults to `None`):
            Layer activations cache.
        layer_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for the layer.
        orig_state_dict (`dict[str, torch.Tensor]` or `None`, *optional*, defaults to `None`):
            Original weight state dict.
    """
    logger = tools.logging.getLogger(f"{__name__}.ActivationQuant")
    logger.debug("- Quantizing layer %s", layer.name)
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    orig_state_dict = orig_state_dict or {}
    # Work list of calibration jobs; each tuple bundles everything needed to
    # build and calibrate one quantizer (see inline field comments below).
    args_caches: list[
        tuple[
            str,  # key
            TensorType,
            list[nn.Linear],  # modules
            str,  # module name
            nn.Module,  # eval module
            str,  # eval name
            dict[str, tp.Any],  # eval kwargs
            list[tuple[nn.Parameter, torch.Tensor]],  # original wgts
        ]
    ] = []
    In, Out = TensorType.Inputs, TensorType.Outputs
    attn, ffn = layer.attn_struct, layer.ffn_struct
    # region attn
    attn_kwargs = attn.filter_kwargs(layer_kwargs)
    if orig_state_dict:
        # Pristine (pre-quantization) weights for q/k/v and out projections,
        # paired with the live parameters so calibration can evaluate against
        # the original weights.
        orig_wgts = [
            (module.weight, orig_state_dict[f"{module_name}.weight"])
            for module_name, module in zip(attn.qkv_proj_names, attn.qkv_proj, strict=True)
        ] + [(attn.out_proj.weight, orig_state_dict[f"{attn.out_proj_name}.weight"])]
    else:
        orig_wgts = None
    # region qkv_proj (Inputs)
    # q/k/v projections share the same input tensor, so a single job keyed on
    # the v-projection name covers all three modules.
    module_name = attn.v_proj_name
    # NOTE(review): `cache_key` computed here is not used — it is recomputed
    # inside the calibration loop below.
    module_key, cache_key, modules = attn.qkv_proj_key, f"{module_name}.input", attn.qkv_proj
    args_caches.append((module_key, In, modules, module_name, attn, attn.name, attn_kwargs, orig_wgts))
    # endregion
    # region qkv_attn (Outputs)
    # The list is doubled so the `idx : idx + 4` slice below always yields four
    # (param, original) pairs starting at the idx-th projection — presumably a
    # wrap-around window over [q, k, v, out]; verify against the calibrator.
    orig_proj_wgts = (orig_wgts + orig_wgts) if orig_wgts else None
    for idx, module_key in enumerate((attn.q_key, attn.k_key, attn.v_key)):
        module = getattr(attn, "qkv"[idx])
        module_name = getattr(attn, f"{'qkv'[idx]}_name")
        cache_key = f"{module_name}.output"
        orig_wgts = orig_proj_wgts[idx : idx + 4] if orig_proj_wgts else None
        args_caches.append((module_key, Out, [module], module_name, attn, attn.name, attn_kwargs, orig_wgts))
    # endregion
    # region out_proj (Inputs)
    module_name, module = attn.out_proj_name, attn.out_proj
    module_key, cache_key = attn.out_proj_key, f"{module_name}.input"
    orig_wgts = [(module.weight, orig_state_dict[f"{module_name}.weight"])] if orig_state_dict else None
    args_caches.append((module_key, In, [module], module_name, module, module_name, None, orig_wgts))
    # endregion
    del orig_wgts
    # endregion
    # region ffn
    # region ffn block projections
    for expert_idx in range(ffn.config.num_experts):
        expert = ffn.experts[expert_idx]
        expert_name = ffn.expert_names[expert_idx]
        # region proj 1st in expert (Inputs)
        module_name = ffn.up_proj_names[expert_idx]
        # Strided slice picks every up-projection belonging to this expert
        # (up_projs is laid out expert-interleaved).
        modules = ffn.up_projs[expert_idx :: ffn.config.num_experts]
        module_key, cache_key = ffn.up_proj_key, f"{module_name}.input"
        if orig_state_dict:
            orig_wgts = [
                (module.weight, orig_state_dict[f"{expert_name}.{ffn.up_proj_rnames[module_idx]}.weight"])
                for module_idx, module in enumerate(modules)
            ]
        else:
            orig_wgts = None
        args_caches.append((module_key, In, modules, module_name, expert, module_name, None, orig_wgts))
        # endregion
        # region proj 2nd in expert (Inputs)
        module_name, module = ffn.down_proj_names[expert_idx], ffn.down_projs[expert_idx]
        module_key, cache_key = ffn.down_proj_key, f"{module_name}.input"
        if orig_state_dict:
            orig_wgts = [(module.weight, orig_state_dict[f"{module_name}.weight"])]
        else:
            orig_wgts = None
        args_caches.append((module_key, In, [module], module_name, module, module_name, None, orig_wgts))
        # endregion
    # endregion
    # endregion
    quantizers: dict[str, LlmActivationQuantizer] = {}
    tools.logging.Formatter.indent_inc()
    for module_key, tensor_type, modules, module_name, eval_module, eval_name, eval_kwargs, orig_wgts in args_caches:
        # Select input- vs output-side config, cached activations, and device.
        if tensor_type == TensorType.Inputs:
            cache_key = f"{module_name}.input"
            quantizer_config = config.ipts
            activations = layer_cache.get(module_name, IOTensorsCache()).inputs
            device = modules[0].weight.device
        else:
            cache_key = f"{module_name}.output"
            quantizer_config = config.opts
            activations = layer_cache.get(module_name, IOTensorsCache()).outputs
            device = attn.out_proj.weight.device
        quantizer = LlmActivationQuantizer(
            quantizer_config,
            channels_dim=-1,
            develop_dtype=config.develop_dtype,
            key=module_key,
            tensor_type=tensor_type,
        )
        if quantizer.is_enabled():
            quantizers[cache_key] = quantizer
            if cache_key not in quantizer_state_dict:
                # No cached settings for this site — calibrate from activations.
                logger.debug("- Calibrating %s", cache_key)
                quantizer.calibrate_dynamic_range(
                    modules=modules,
                    activations=activations,
                    eval_module=eval_module,
                    eval_inputs=layer_cache[eval_name].inputs if layer_cache else None,
                    eval_kwargs=eval_kwargs,
                    orig_weights=orig_wgts,
                )
                quantizer_state_dict[cache_key] = quantizer.state_dict()
                gc.collect()
                torch.cuda.empty_cache()
            else:
                quantizer.load_state_dict(quantizer_state_dict[cache_key], device=device)
            if tensor_type == TensorType.Inputs:
                # Alias the shared-input quantizer to its sibling projections
                # (q/k share v's input; all up-projections share the first's).
                if attn.v_proj_rname in cache_key:
                    for proj_name in [attn.q_proj_rname, attn.k_proj_rname]:
                        quantizers[cache_key.replace(attn.v_proj_rname, proj_name)] = quantizer
                if ffn.up_proj_rnames[0] in cache_key:
                    for proj_name in ffn.up_proj_rnames[1:]:
                        quantizers[cache_key.replace(ffn.up_proj_rnames[0], proj_name)] = quantizer
        del quantizer
    # Install the calibrated quantizers as forward hooks on the layer modules.
    for name, module in layer.module.named_modules():
        module_name = f"{layer.name}.{name}"
        ipts_quantizer = quantizers.get(f"{module_name}.input", None)
        opts_quantizer = quantizers.get(f"{module_name}.output", None)
        needs_quant_ipts = ipts_quantizer is not None and ipts_quantizer.is_enabled()
        needs_quant_opts = opts_quantizer is not None and opts_quantizer.is_enabled()
        if needs_quant_ipts or needs_quant_opts:
            logger.debug(
                "- Quantizing %s (%s)",
                module_name,
                ("inputs" if needs_quant_ipts else "")
                + (" and " if needs_quant_ipts and needs_quant_opts else "")
                + ("outputs" if needs_quant_opts else ""),
            )
            if needs_quant_ipts:
                ipts_quantizer.as_hook(is_output=False).register(module)
            if needs_quant_opts:
                opts_quantizer.as_hook(is_output=True).register(module)
    tools.logging.Formatter.indent_dec()
@torch.inference_mode()
def quantize_llm_activations(
    model: nn.Module | LlmModelStruct,
    config: LlmQuantConfig,
    tokenizer: PreTrainedTokenizer | None = None,
    quantizer_state_dict: dict[str, tp.Any] | None = None,
    orig_state_dict: dict[str, torch.Tensor] | None = None,
) -> dict[str, tp.Any]:
    """Quantize the large foundation model activations.

    Args:
        model (`nn.Module` or `LlmStruct`):
            Model to be quantized.
        config (`LlmQuantConfig`):
            Quantization configuration.
        tokenizer (`PreTrainedTokenizer`, *optional*, defaults to `None`):
            Tokenizer.
        quantizer_state_dict (`dict[str, Any]`, *optional*, defaults to `None`):
            Activation quantizer state dict cache.
        orig_state_dict (`dict[str, torch.Tensor]`, *optional*, defaults to `None`):
            Original weight state dict

    Returns:
        `dict[str, Any]`:
            Activation quantizer state dict cache.
    """
    # Normalize the input to a model struct wrapper.
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    if not quantizer_state_dict:
        quantizer_state_dict = {}
    # Calibration is required only when no cached settings were supplied and
    # the configuration actually needs calibration data.
    needs_calibration = not quantizer_state_dict and config.needs_acts_quantizer_cache
    layer_structs = model.backbone_struct.layer_structs
    with tools.logging.redirect_tqdm():
        if needs_calibration:
            # Stream calibration activations layer by layer and calibrate.
            activations_iter = config.calib.build_loader(tokenizer).iter_layer_activations(
                model,
                needs_inputs_fn=get_needs_inputs_fn(model=model, config=config),
                needs_outputs_fn=get_needs_outputs_fn(model=model, config=config),
            )
            progress = tqdm(
                activations_iter,
                desc="quantizing activations",
                leave=False,
                total=len(layer_structs),
                dynamic_ncols=True,
            )
            for _layer_key, (layer_struct, cached_acts, extra_kwargs) in progress:
                quantize_llm_layer_activations(
                    layer=layer_struct,
                    config=config,
                    quantizer_state_dict=quantizer_state_dict,
                    layer_cache=cached_acts,
                    layer_kwargs=extra_kwargs,
                    orig_state_dict=orig_state_dict,
                )
        else:
            # Settings are already available (or not needed): just install the
            # quantizer hooks on every layer without running calibration data.
            progress = tqdm(layer_structs, desc="quantizing activations", leave=False, dynamic_ncols=True)
            for layer_struct in progress:
                quantize_llm_layer_activations(
                    layer=layer_struct,
                    config=config,
                    quantizer_state_dict=quantizer_state_dict,
                    orig_state_dict=orig_state_dict,
                )
    return quantizer_state_dict
================================================
FILE: deepcompressor/app/llm/quant/config.py
================================================
# -*- coding: utf-8 -*-
"""Quantization config."""
import os
from dataclasses import dataclass, field
import torch
from omniconfig import configclass
from deepcompressor.calib.config import (
QuantRotationConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
SkipBasedChannelOrderConfig,
SmoothTransfomerConfig,
)
from deepcompressor.data.utils.dtype import eval_dtype
from deepcompressor.utils.common import num2str
from ..cache.config import LlmQuantCacheConfig
from ..nn.struct import LlmFeedForwardStruct, LlmSelfAttentionStruct
from .dataset import LlmCalibDataLoaderConfig
from .quantizer import LlmModuleQuantizerConfig
__all__ = ["LlmQuantConfig"]
@configclass
@dataclass
class LlmQuantConfig(LlmModuleQuantizerConfig):
"""Large Language Model Module quantization configuration.
Args:
wgts (`LlmWeightQuantizerConfig`):
The weight quantization configuration.
ipts (`LlmActivationQuantizerConfig`):
The input activation quantization configuration.
opts (`LlmActivationQuantizerConfig`):
The output activation quantization configuration.
calib (`LlmCalibDataLoaderConfig`):
The calibration dataset configuration.
rotation (`QuantRotationConfig` or `None`, *optional*, defaults to `None`):
The quantization rotation configuration.
reorder (`SkipBasedChannelOrderConfig` or `None`, *optional*, defaults to `None`):
The quantization reordering configuration.
smooth (`SmoothTransfomerConfig`, *optional*, defaults to `None`):
The quantization smoothing configuration.
develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The development data type during quantization.
"""
calib: LlmCalibDataLoaderConfig
rotation: QuantRotationConfig | None = None
reorder: SkipBasedChannelOrderConfig | None = None
smooth: SmoothTransfomerConfig | None = None
develop_dtype: torch.dtype = field(default_factory=lambda s=torch.float32: eval_dtype(s, with_quant_dtype=False))
def __post_init__(self) -> None: # noqa: C901
if self.smooth is not None:
if not self.smooth.enabled_proj and not self.smooth.enabled_attn:
self.smooth = None
if self.rotation is not None and self.reorder is not None:
self.reorder.skips.append("residual")
if self.rotation.transforms:
self.reorder.skips.extend(self.rotation.transforms)
self.reorder.skips = sorted(set(self.reorder.skips))
if self.enabled_ipts:
if self.ipts.enabled_calib_range and self.ipts.calib_range.granularity == SearchBasedCalibGranularity.Group:
self.ipts.calib_range.granularity = SearchBasedCalibGranularity.ChannelGroup
if self.ipts.static:
assert self.ipts.smallest_group_shape[0] == -1, "static quantization requires batch group size to be -1"
if self.enabled_opts:
if self.opts.enabled_calib_range and self.opts.calib_range.granularity == SearchBasedCalibGranularity.Group:
self.opts.calib_range.granularity = SearchBasedCalibGranularity.ChannelGroup
if self.opts.static:
assert self.opts.smallest_group_shape[0] == -1, "static quantization requires batch group size to be -1"
if self.enabled_reorder:
if not self.reorder.dynamic:
qkv_proj_rkey, up_proj_rkey = LlmSelfAttentionStruct.qkv_proj_rkey, LlmFeedForwardStruct.up_proj_rkey
skips_to_remove = []
for skip in self.reorder.skips:
if skip.startswith(qkv_proj_rkey) or skip.endswith(f"_{qkv_proj_rkey}"):
self.reorder.skips.append("residual")
skips_to_remove.append(skip)
elif skip.startswith(up_proj_rkey) or skip.endswith(f"_{up_proj_rkey}"):
self.reorder.skips.append("residual")
skips_to_remove.append(skip)
self.reorder.skips = sorted(set(self.reorder.skips))
for skip in skips_to_remove:
self.reorder.skips.remove(skip)
self.reorder.skips = sorted(set(self.reorder.skips))
@property
def enabled_smooth(self) -> bool:
"""Whether to enable smooth quantization."""
return self.smooth is not None
@property
def enabled_smooth_proj(self) -> bool:
"""Whether to enable xw smooth quantization."""
return self.enabled_smooth and self.smooth.enabled_proj
@property
def enabled_smooth_attn(self) -> bool:
"""Whether to enable yy smooth quantization."""
return self.enabled_smooth and self.smooth.enabled_attn
@property
def enabled_reorder(self) -> bool:
"""Whether to enable channel reorder."""
return self.reorder is not None and self.reorder.is_enabled()
@property
def enabled_rotation(self) -> bool:
"""Whether to enable rotation."""
return self.rotation is not None
@property
def needs_acts_quantizer_cache(self) -> bool:
"""Whether to cache the activations quantizer settings."""
if self.enabled_ipts and self.ipts.needs_calib_data:
return True
if self.enabled_opts and self.opts.needs_calib_data:
return True
return False
def generate_calib_dirname(self) -> str:
name = ""
if self.enabled_rotation:
name += "-rotate"
if self.rotation.random:
name += ".rnd"
if self.enabled_reorder:
name += "-reorder"
if self.reorder.dynamic:
name += ".dyn"
if self.enabled_smooth:
name += "-smooth"
if self.enabled_smooth_proj:
name += ".proj"
if self.enabled_smooth_attn:
name += ".attn"
calib_name = super().generate_calib_dirname()
if calib_name:
name += f"-{calib_name}"
return name[1:] if name else name
def generate_default_dirname(self) -> str: # noqa: C901
"""Generate directory name for a large language model quantization configuration."""
w_names = x_names = {"qkv_proj": "qkv", "out_proj": "out", "up_proj": "fc1", "down_proj": "fc2"}
y_names = {"attn_q": "q", "attn_k": "k", "attn_v": "v"}
skip_name = ""
if self.enabled_opts:
skip_y_name = "+".join(y_names[y] for y in self.opts.skips if y in y_names)
if skip_y_name:
skip_name += f".y.[{skip_y_name}]"
if self.enabled_wgts:
skip_w_name = "+".join(w_names[w] for w in self.wgts.skips if w in w_names)
if skip_w_name:
skip_name += f".w.[{skip_w_name}]"
if self.enabled_ipts:
skip_x_name = "+".join(x_names[x] for x in self.ipts.skips if x in x_names)
if skip_x_name:
skip_name += f".x.[{skip_x_name}]"
if skip_name:
skip_name = "-skip" + skip_name
if self.enabled_wgts and self.wgts.enabled_gptq:
skip_name += "-gptq"
rotation_name = ""
if self.enabled_rotation:
rotation_name = "-rot"
if self.rotation.path:
rotation_name += f".{self.rotation.name}"
elif self.rotation.random:
rotation_name += ".rnd"
if self.rotation.transforms:
rotation_name += ".[+{}]".format("+".join(w_names[w] for w in self.rotation.transforms))
reorder_name = ""
if self.enabled_reorder:
reorder_name = "-rodr"
if self.reorder.strategy == SearchBasedCalibStrategy.Manual:
if self.reorder.channel_metric.value != "xMax":
reorder_name += f".{self.reorder.channel_metric.value}"
if self.reorder.channel_index.value != "Seq":
reorder_name += f".{self.reorder.channel_index.value}"
else:
reorder_name += f".{self.reorder.strategy.name}"
reorders, skips = [], []
for k in w_names.keys() if self.reorder.dynamic else ("residual", "out_proj", "down_proj"):
v = w_names.get(k, "res")
if k in self.reorder.skips:
skips.append(v)
else:
reorders.append(v)
if len(reorders) <= len(skips):
reorder_name += ".[{}]".format("+".join(reorders))
elif skips:
reorder_name += ".skip.[{}]".format("+".join(skips))
smooth_name = ""
if self.enabled_smooth:
smooth_name = "-smth"
if self.smooth.enabled_proj:
smooth_name += ".proj"
if self.smooth.proj.granularity != SearchBasedCalibGranularity.Layer:
smooth_name += f".{self.smooth.proj.granularity.name}"
if self.smooth.proj.strategy != SearchBasedCalibStrategy.Manual:
smooth_name += f".{self.smooth.proj.strategy.name}"
if self.smooth.proj.alpha <= 0:
smooth_name += f".a{num2str(self.smooth.proj.alpha)}"
if self.smooth.proj.beta <= 0:
smooth_name += f".b{num2str(self.smooth.proj.beta)}"
else:
smooth_name += f".a{num2str(self.smooth.proj.alpha)}"
smooth_name += f".b{num2str(self.smooth.proj.beta)}"
xspan_eq_wspan = True
for xspan, wspan in self.smooth.proj.spans:
if xspan != wspan:
xspan_eq_wspan = False
break
if xspan_eq_wspan:
smooth_name += ".[{}]".format("+".join(xspan.name for xspan, _ in self.smooth.proj.spans))
else:
smooth_name += ".[{}]".format(
"+".join(f"x.{xspan.name}.w.{wspan.name}" for xspan, wspan in self.smooth.proj.spans)
)
smooths, skips = [], []
for k, v in w_names.items():
if k in self.smooth.proj.skips:
skips.append(v)
else:
smooths.append(v)
if len(smooths) <= len(skips):
smooth_name += ".[{}]".format("+".join(smooths))
elif skips:
smooth_name += ".skip.[{}]".format("+".join(skips))
if self.smooth.enabled_attn:
smooth_name += ".attn"
if self.smooth.attn.granularity != SearchBasedCalibGranularity.Layer:
smooth_name += f".{self.smooth.attn.granularity.name}"
if self.smooth.attn.strategy != SearchBasedCalibStrategy.Manual:
smooth_name += f".{self.smooth.attn.strategy.name}"
if self.smooth.attn.alpha <= 0:
smooth_name += f".a{num2str(self.smooth.attn.alpha)}"
if self.smooth.attn.beta <= 0:
smooth_name += f".b{num2str(self.smooth.attn.beta)}"
else:
smooth_name += f".a{num2str(self.smooth.attn.alpha)}"
smooth_name += f".b{num2str(self.smooth.attn.beta)}"
xspan_eq_yspan = True
for xspan, yspan in self.smooth.attn.spans:
if xspan != yspan:
xspan_eq_yspan = False
break
if xspan_eq_yspan:
smooth_name += ".[{}]".format("+".join(xspan.name for xspan, _ in self.smooth.attn.spans))
else:
smooth_name += ".[{}]".format(
"+".join(f"x.{xspan.name}.y.{yspan.name}" for xspan, yspan in self.smooth.attn.spans)
)
wrange_name = ""
if (
self.enabled_wgts
and self.wgts.enabled_calib_range
and (self.wgts.calib_range.needs_search or self.wgts.calib_range.ratio != 1)
):
wrange_name = "-w.range"
if self.wgts.calib_range.needs_search:
if self.wgts.calib_range.granularity != SearchBasedCalibGranularity.Group:
wrange_name += f".{self.wgts.calib_range.granularity.name}"
if self.wgts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
wrange_name += f".{self.wgts.calib_range.objective.name}"
if self.wgts.calib_range.degree != 2:
wrange_name += f".d{num2str(self.wgts.calib_range.degree)}"
wrange_name += f".[{num2str(self.wgts.calib_range.max_shrink)}"
wrange_name += f".{num2str(self.wgts.calib_range.max_expand)}"
wrange_name += f".g{self.wgts.calib_range.num_grids}]"
else:
wrange_name += f".r{num2str(self.wgts.calib_range.ratio)}"
if self.wgts.calib_range.skips:
wrange_name += ".skip.[{}]".format("+".join(w_names[w] for w in self.wgts.calib_range.skips))
xrange_name = ""
if (
self.enabled_ipts
and self.ipts.enabled_calib_range
and (self.ipts.calib_range.needs_search or self.ipts.calib_range.ratio != 1)
):
xrange_name = "-x.range"
if self.ipts.calib_range.needs_search:
if self.ipts.calib_range.granularity != SearchBasedCalibGranularity.Group:
xrange_name += f".{self.ipts.calib_range.granularity.name}"
if self.ipts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
xrange_name += f".{self.ipts.calib_range.objective.name}"
if self.ipts.calib_range.degree != 2:
xrange_name += f".d{num2str(self.ipts.calib_range.degree)}"
xrange_name += f".[{num2str(self.ipts.calib_range.max_shrink)}"
xrange_name += f".{num2str(self.ipts.calib_range.max_expand)}"
xrange_name += f".g{self.ipts.calib_range.num_grids}]"
else:
xrange_name += f".r{num2str(self.ipts.calib_range.ratio)}"
if self.ipts.calib_range.skips:
xrange_name += ".skip.[{}]".format("+".join(w_names[w] for w in self.ipts.calib_range.skips))
yrange_name = ""
if (
self.enabled_opts
and self.opts.enabled_calib_range
and (self.opts.calib_range.needs_search or self.opts.calib_range.ratio != 1)
):
yrange_name = "-y.range"
if self.opts.calib_range.needs_search:
if self.opts.calib_range.granularity != SearchBasedCalibGranularity.Group:
yrange_name += f".{self.opts.calib_range.granularity.name}"
if self.opts.calib_range.objective != SearchBasedCalibObjective.OutputsError:
yrange_name += f".{self.opts.calib_range.objective.name}"
if self.opts.calib_range.degree != 2:
yrange_name += f".d{num2str(self.opts.calib_range.degree)}"
yrange_name += f".[{num2str(self.opts.calib_range.max_shrink)}"
yrange_name += f".{num2str(self.opts.calib_range.max_expand)}"
yrange_name += f".g{self.opts.calib_range.num_grids}]"
else:
yrange_name += f".r{num2str(self.opts.calib_range.ratio)}"
if self.opts.calib_range.skips:
yrange_name += ".skip.[{}]".format("+".join(y_names[y] for y in self.opts.calib_range.skips))
name = skip_name + rotation_name + reorder_name + smooth_name + wrange_name + xrange_name + yrange_name
name = name[1:] if name else "default"
name += f"-{self.calib.generate_dirnames()[0]}"
return name
def generate_cache_dirpath(
    self, *, root: str, seed: int, default_dtype: torch.dtype = torch.float16
) -> LlmQuantCacheConfig:  # noqa: C901
    """Generate the cache paths for the module quantization configuration.

    Args:
        root (`str`):
            Root directory of the cache.
        seed (`int`):
            Random seed; only referenced for the rotation path when a random
            rotation (rather than Hadamard) is configured.
        default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            Default dtype used when generating the quantization directory names.

    Returns:
        `LlmQuantCacheConfig`:
            Cache directory paths for the reorder, smooth, weight-range,
            activation-range, and rotation calibration results.
    """
    # `quant_names` accumulates name components stage by stage, so each
    # stage's cache path encodes the settings of every stage before it.
    quant_names = self.generate_dirnames(default_dtype=default_dtype)
    # GPTQ kernel names are computed up front but appended only after the
    # smooth stage, so reorder/smooth caches are shared across GPTQ settings.
    w_kernel_names = []
    if self.enabled_wgts and self.wgts.enabled_gptq:
        w_kernel_names = self.wgts.kernel_gptq.generate_dirnames(prefix="w.kernel")
    if self.enabled_rotation:
        quant_names.extend(self.rotation.generate_dirnames(prefix="rotate"))
    reorder_dirpath = ""
    if self.enabled_reorder:
        reorder_names = self.reorder.generate_dirnames(prefix="reorder")
        quant_names.extend(reorder_names)
        reorder_dirpath = os.path.join("reorder", *quant_names)
    smooth_dirpath = ""
    if self.enabled_smooth:
        smooth_names = self.smooth.generate_dirnames(prefix="smooth")
        quant_names.extend(smooth_names)
        smooth_dirpath = os.path.join("smooth", *quant_names)
    quant_names.extend(w_kernel_names)
    wgts_dirpath = ""
    if self.enabled_wgts and self.wgts.enabled_calib_range:
        quant_names.extend(self.wgts.calib_range.generate_dirnames(prefix="w.range"))
        wgts_dirpath = os.path.join("wgts", *quant_names)
    acts_dirpath = ""
    if self.needs_acts_quantizer_cache:
        # the activation path covers both input ("x") and output ("y") ranges
        if self.enabled_ipts and self.ipts.enabled_calib_range:
            quant_names.extend(self.ipts.calib_range.generate_dirnames(prefix="x.range"))
        if self.enabled_opts and self.opts.enabled_calib_range:
            quant_names.extend(self.opts.calib_range.generate_dirnames(prefix="y.range"))
        acts_dirpath = os.path.join("acts", *quant_names)
    # prepend the calibration-dataset names as parent directories of all paths
    cache_dirpath = LlmQuantCacheConfig(
        reorder=reorder_dirpath,
        smooth=smooth_dirpath,
        wgts=wgts_dirpath,
        acts=acts_dirpath,
    ).add_parent_dirs(*self.calib.generate_dirnames())
    if self.enabled_rotation:
        if self.rotation.path:
            # an explicit rotation matrix path is configured, so no cache
            # directory is generated for the rotation
            cache_dirpath.rotation = ""
        else:
            cache_dirpath.rotation = os.path.join(
                "rotation",
                f"seed.{seed}" if self.rotation.random else "hadamard",
            )
    cache_dirpath.add_parent_dirs(root, "llm", "cache", "quant")
    return cache_dirpath
================================================
FILE: deepcompressor/app/llm/quant/dataset.py
================================================
# -*- coding: utf-8 -*-
"""Functions for collecting calibration dataset for quantization."""
import os
import random
import typing as tp
from dataclasses import MISSING, dataclass, field
import torch
import torch.nn as nn
import torch.utils.data
from datasets import load_dataset
from omniconfig import configclass
from transformers import PreTrainedTokenizer
from transformers.cache_utils import Cache
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from transformers.models.t5.modeling_t5 import T5DenseActDense, T5DenseGatedActDense
from deepcompressor.data.cache import IOTensorsCache, ModuleForwardInput, TensorCache
from deepcompressor.data.utils.reshape import LinearReshapeFn
from deepcompressor.dataset.action import CacheAction, ConcatCacheAction
from deepcompressor.dataset.cache import BaseCalibCacheLoader
from deepcompressor.dataset.config import BaseDataLoaderConfig
from ..nn.patch import RotaryEmbedding
from ..nn.struct import LlmModelStruct, LlmTransformerBlockStruct
__all__ = ["LlmCalibDataLoaderConfig", "LlmCalibCacheLoader"]
@configclass
@dataclass(kw_only=True)
class LlmCalibDataLoaderConfig(BaseDataLoaderConfig):
    """Configuration for collecting calibration dataset for quantization.

    Args:
        data (`str`):
            Dataset name.
        num_samples (`int`):
            Number of dataset samples.
        path (`str`):
            Path to the dataset.
        seq_length (`int`):
            Sequence length of each sample.
        min_seq_length (`int`, *optional*, defaults to `0`):
            Minimum sequence length of each sample.
        max_seq_length (`int`, *optional*, defaults to `0`):
            Maximum sequence length of each sample.
        local_path (`str`, *optional*, defaults to `""`):
            Local path to the dataset.
    """

    path: str
    seq_length: int
    min_seq_length: int = 0
    max_seq_length: int = 0
    local_path: str = ""
    # fixed to 1 and excluded from the constructor/CLI: calibration batches
    # are clamped downstream by the cache loader
    batch_size: int = field(init=False, default=1)

    def __post_init__(self) -> None:
        # negative bounds are normalized to 0, which means "no bound"
        self.min_seq_length = max(0, self.min_seq_length)
        self.max_seq_length = max(0, self.max_seq_length)
        self.path = os.path.expanduser(self.path)
        self.local_path = os.path.expanduser(self.local_path)
        # prefer a local copy of the dataset when one exists on disk
        if os.path.exists(self.local_path):
            self.path = self.local_path

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Get the names of the configuration fields."""
        # e.g. "pileval.128x512.[0-0]" — encodes dataset, sample count/length,
        # and the [min-max] sequence-length filter
        name = f"{self.data}.{self.num_samples}x{self.seq_length}.[{self.min_seq_length}-{self.max_seq_length}]"
        return [f"{prefix}.{name}" if prefix else name]

    def build_dataset(self, tokenizer: PreTrainedTokenizer) -> "LlmCalibDataset":
        """Build calibration dataset.

        Args:
            tokenizer (`PreTrainedTokenizer`):
                Tokenizer for encoding text.

        Returns:
            `LlmCalibDataset`:
                Calibration dataset.
        """
        return LlmCalibDataset(
            tokenizer,
            data=self.data,
            path=self.path,
            num_samples=self.num_samples,
            seq_length=self.seq_length,
            max_seq_length=self.max_seq_length,
            min_seq_length=self.min_seq_length,
        )

    def build_loader(self, tokenizer: PreTrainedTokenizer) -> "LlmCalibCacheLoader":
        """Build calibration data cache loader.

        Args:
            tokenizer (`PreTrainedTokenizer`):
                Tokenizer for encoding text.

        Returns:
            `LlmCalibCacheLoader`:
                Calibration data cache loader.
        """
        return LlmCalibCacheLoader(config=self, tokenizer=tokenizer)
class LlmCalibDataset(torch.utils.data.Dataset):
    """Calibration dataset of fixed-length token sequences drawn from a text corpus."""

    # sequence of token-id tensors, each of shape (seq_length,)
    data: list[torch.Tensor]

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        data: str,
        path: str,
        num_samples: int,
        seq_length: int,
        max_seq_length: int = -1,
        min_seq_length: int = -1,
        seed: int = 42,
    ) -> None:
        """Build the calibration dataset.

        Args:
            tokenizer (`PreTrainedTokenizer`):
                Tokenizer used to encode raw text into token ids.
            data (`str`):
                Dataset name; only ``"pileval"`` is supported.
            path (`str`):
                Dataset path passed to ``datasets.load_dataset``.
            num_samples (`int`):
                Number of fixed-length samples to keep.
            seq_length (`int`):
                Length of every sample.
            max_seq_length (`int`, *optional*, defaults to `-1`):
                If positive, drop encoded lines longer than this.
            min_seq_length (`int`, *optional*, defaults to `-1`):
                If positive, drop encoded lines shorter than this.
            seed (`int`, *optional*, defaults to `42`):
                Seed used for both shuffling and random cropping.
        """
        assert num_samples > 0, "num_samples should be positive"
        assert seq_length > 0, "seq_length should be positive"
        assert tokenizer is not None, "tokenizer is required"
        num_tokens = num_samples * seq_length
        if data != "pileval":
            raise NotImplementedError(f"Calibration dataset {data} is not supported")
        corpus = load_dataset(path, split="validation").shuffle(seed=seed)
        rand = random.Random(seed)
        collected: list[torch.Tensor] = []
        total = 0
        for sample in corpus:
            token_ids = tokenizer.encode(sample["text"].strip())
            n = len(token_ids)
            # skip empty lines and lines outside the configured length window
            if n == 0 or (min_seq_length > 0 and n < min_seq_length) or (max_seq_length > 0 and n > max_seq_length):
                continue
            seq = torch.tensor(token_ids)
            if n > seq_length:
                # randomly crop a window of exactly seq_length tokens
                start = rand.randint(0, n - seq_length)
                seq = seq[start : start + seq_length]
            collected.append(seq)
            total += seq.numel()
            # stop once we have enough samples AND enough total tokens
            if len(collected) >= num_samples and total >= num_tokens:
                break
        # concatenate everything and re-chunk into seq_length-sized pieces
        chunks = torch.cat(collected).split(seq_length)
        if total > num_tokens:
            # the trailing chunk is shorter than seq_length; discard it
            chunks = chunks[:-1]
        self.data = chunks[:num_samples]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]
class LlmCalibCacheLoader(BaseCalibCacheLoader):
    """Cache for collecting calibration dataset for quantizing large language models."""

    config: LlmCalibDataLoaderConfig
    dataset: LlmCalibDataset

    def __init__(self, config: LlmCalibDataLoaderConfig, tokenizer: PreTrainedTokenizer) -> None:
        """Initialize large language model calibration cache loader.

        Args:
            config (`LlmCalibDataLoaderConfig`):
                Configuration for loading calibration dataset.
            tokenizer (`PreTrainedTokenizer`):
                Tokenizer for encoding text.
        """
        super().__init__(dataset=config.build_dataset(tokenizer=tokenizer), batch_size=config.batch_size)
        # clamp so the batch size never exceeds the number of available samples
        self.batch_size = min(self.batch_size, len(self.dataset))
        self.config = config

    def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
        """Initialize cache.

        Args:
            name (`str`):
                Module name.
            module (`nn.Module`):
                Module.

        Returns:
            `IOTensorsCache`:
                Input and output tensors cache.
        """
        # Linear-like modules (and decoder/attention/MLP containers, matched by
        # class-name suffix) treat the last dimension as channels and can be
        # flattened with a linear reshape.
        if isinstance(
            module, (nn.Linear, RotaryEmbedding, MixtralSparseMoeBlock, T5DenseActDense, T5DenseGatedActDense)
        ) or module.__class__.__name__.endswith(("DecoderLayer", "Attention", "MLP")):
            return IOTensorsCache(
                inputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
                outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
            )
        else:
            # bug fix: the base-class cache was previously computed but not
            # returned, so these modules silently received `None`
            return super()._init_cache(name, module)

    def _convert_layer_inputs(
        self, m: nn.Module, args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], save_all: bool = False
    ) -> ModuleForwardInput:
        """Convert layer inputs to module forward input.

        Args:
            m (`nn.Module`):
                Layer.
            args (`tuple[Any, ...]`):
                Layer input arguments.
            kwargs (`dict[str, Any]`):
                Layer input keyword arguments.
            save_all (`bool`, *optional*, defaults to `False`):
                Whether to save all inputs.

        Returns:
            `ModuleForwardInput`:
                Module forward input.
        """
        # keep a CPU copy of the first positional input only when requested;
        # otherwise record a MISSING placeholder to avoid holding activations
        x = args[0].detach().cpu() if save_all else MISSING
        # `Cache` objects (transformers KV caches) are replaced with None so a
        # replayed forward pass does not reuse stale key/value state
        return ModuleForwardInput(
            args=[x, *args[1:]], kwargs={k: None if isinstance(v, Cache) else v for k, v in kwargs.items()}
        )

    def iter_samples(self) -> tp.Generator[ModuleForwardInput, None, None]:
        """Iterate over model input samples.

        Yields:
            `ModuleForwardInput`:
                Module forward input.
        """
        dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self.batch_size, shuffle=False, drop_last=True
        )
        for data in dataloader:
            yield ModuleForwardInput(args=(data,))

    def iter_layer_activations(  # noqa: C901
        self,
        model: nn.Module | LlmModelStruct,
        *args,
        action: CacheAction | None = None,
        needs_inputs_fn: tp.Callable[[str, nn.Module], bool] | bool | None = True,
        needs_outputs_fn: tp.Callable[[str, nn.Module], bool] | bool | None = None,
        **kwargs,
    ) -> tp.Generator[
        tuple[
            str,
            tuple[
                LlmTransformerBlockStruct,
                dict[str, IOTensorsCache],
                dict[str, tp.Any],
            ],
        ],
        None,
        None,
    ]:
        """Iterate over model activations for each layer.

        Args:
            model (`nn.Module`):
                Model.
            action (`CacheAction`, *optional*, defaults to `None`):
                Action for caching activations. If ``None``, ``ConcatCacheAction("cpu")`` is used.
            needs_inputs_fn (`Callable[[str, nn.Module], bool]` or `bool` or `None`, *optional*, defaults to `True`):
                Function for determining whether to cache inputs for a module given its name and itself.
            needs_outputs_fn (`Callable[[str, nn.Module], bool]` or `bool` or `None`, *optional*, defaults to `None`):
                Function for determining whether to cache outputs for a module given its name and itself.
            *args: Arguments for ``_iter_samples``.
            **kwargs: Keyword arguments for ``_iter_samples``.

        Yields:
            Generator[
                tuple[str, tuple[LlmTransformerBlockStruct, dict[str, IOTensorsCache], dict[str, Any]]],
                None,
                None
            ]:
                Generator of tuple of
                    - layer name
                    - a tuple of
                        - layer struct,
                        - input and output caches for each module in the layer,
                        - layer input keyword arguments.
        """
        if isinstance(model, LlmModelStruct):
            model_struct = model
            model = model_struct.module
        else:
            model_struct = LlmModelStruct.construct(model)
        layers, layer_structs, recomputes, use_prev_layer_outputs = model_struct.get_iter_layer_activations_args()
        action = ConcatCacheAction("cpu") if action is None else action
        for layer_idx, (layer_name, (layer, layer_cache, layer_inputs)) in enumerate(
            self._iter_layer_activations(
                model,
                *args,
                action=action,
                layers=layers,
                needs_inputs_fn=needs_inputs_fn,
                needs_outputs_fn=needs_outputs_fn,
                recomputes=recomputes,
                use_prev_layer_outputs=use_prev_layer_outputs,
                **kwargs,
            )
        ):
            # all samples must share identical keyword arguments for a layer,
            # so a single kwargs dict can represent the whole batch
            layer_kwargs = layer_inputs[0].kwargs
            for layer_input in layer_inputs:
                for key, value in layer_input.kwargs.items():
                    if isinstance(value, torch.Tensor):
                        assert torch.equal(value, layer_kwargs[key])
                    else:
                        assert value == layer_kwargs[key]
            layer_struct = layer_structs[layer_idx]
            assert layer_name == layer_struct.name, f"Expected {layer_struct.name}, got {layer_name}"
            assert layer is layer_struct.module
            for transformer_block_struct in layer_struct.iter_transformer_block_structs():
                # q/k projections consume the same input as the v projection,
                # so they can all share one cache entry
                for attn_struct in transformer_block_struct.iter_attention_structs():
                    if attn_struct.v_proj_name in layer_cache:
                        cache = layer_cache[attn_struct.v_proj_name]
                        layer_cache[attn_struct.q_proj_name] = cache
                        layer_cache[attn_struct.k_proj_name] = cache
                ffn_struct = transformer_block_struct.ffn_struct
                up_proj_names = ffn_struct.up_proj_names
                if up_proj_names[0] in layer_cache:
                    # projections of the same expert (stride = num_experts in
                    # `up_proj_names`) share the expert's cache entry
                    for expert_idx in range(ffn_struct.config.num_experts):
                        cache = layer_cache[up_proj_names[expert_idx]]
                        for name in up_proj_names[expert_idx :: ffn_struct.config.num_experts]:
                            layer_cache[name] = cache
                if ffn_struct.config.num_experts == 1 and ffn_struct.name not in layer_cache:
                    # dense FFN: the block's input is the up projection's input
                    layer_cache[ffn_struct.name] = layer_cache[up_proj_names[0]]
                if ffn_struct.config.num_experts > 1 and ffn_struct.name in layer_cache:
                    # MoE FFN: the router gate sees the same input as the block
                    layer_cache[ffn_struct.moe_gate_name] = layer_cache[ffn_struct.name]
            yield layer_name, (layer_struct, layer_cache, layer_kwargs)
================================================
FILE: deepcompressor/app/llm/quant/quantizer/__init__.py
================================================
# -*- coding: utf-8 -*-
from .config import LlmModuleQuantizerConfig
from .quantizer import LlmActivationQuantizer, LlmWeightQuantizer
================================================
FILE: deepcompressor/app/llm/quant/quantizer/config.py
================================================
# -*- coding: utf-8 -*-
"""Quantizatizer config."""
import typing as tp
from dataclasses import dataclass, field
import torch
from omniconfig import configclass
from deepcompressor.calib.config import SkipBasedDynamicRangeCalibConfig
from deepcompressor.data.dtype import QuantDataType
from deepcompressor.quantizer.config import ProgressiveQuantizerConfig
from deepcompressor.quantizer.kernel import QuantGptqConfig
from deepcompressor.utils.config import EnableConfig, SkipBasedConfig
__all__ = ["LlmQuantizerConfig", "LlmWeightQuantizerConfig", "LlmActivationQuantizerConfig", "LlmModuleQuantizerConfig"]
@configclass
@dataclass
class LlmQuantizerConfig(SkipBasedConfig, ProgressiveQuantizerConfig):
    """Llm Quantizer Configuration.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data type for per-group quantization.
        intermediate_dtypes (`Sequence[QuantDataType]`, *optional*, defaults to `()`):
            The intermediate quantization data types.
        intermediate_levels (Sequence[int], *optional*, defaults to `()`):
            The intermediate quantization levels.
        needs_dequant_saturation (`bool`, *optional*, defaults to `False`):
            Whether the dequantization needs saturation.
        skips (`Sequence[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        static (`bool`, *optional*, defaults to `False`):
            Whether to use static quantization.
        kernel_gptq (`QuantGptqConfig` or `None`, *optional*, defaults to `None`):
            The GPTQ kernel configuration.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The dynamic range calibration configuration.
    """

    # whether to use static (offline-calibrated) quantization scales
    static: bool = False
    # GPTQ kernel calibration configuration; None disables GPTQ
    kernel_gptq: QuantGptqConfig | None = None
    # dynamic-range calibration configuration; None disables it
    calib_range: SkipBasedDynamicRangeCalibConfig | None = None

    def __post_init__(self) -> None:
        super().__post_init__()
        # quantization disabled entirely: drop all calibration settings
        if self.quant_dtype is None:
            self.static = False
            self.kernel_gptq = None
            self.calib_range = None
        # static quantization requires a range-calibration config to compute
        # the offline scales; fall back to the default one
        if self.static and self.calib_range is None:
            self.calib_range = SkipBasedDynamicRangeCalibConfig()

    @property
    def enabled_gptq(self) -> bool:
        """Whether quantization kernel calibration is enabled."""
        return self.kernel_gptq is not None

    @property
    def enabled_calib_range(self) -> bool:
        """Whether quantization dynamic range calibration is enabled."""
        return self.calib_range is not None

    @property
    def needs_calib_data(self) -> bool:
        """Whether this quantizer requires calibration data (search-based range or static scales)."""
        return self.enabled_calib_range and (self.calib_range.needs_search or self.static)

    def generate_calib_dirname(self) -> str:
        """Generate the name for quantization calibration.

        Returns:
            str: The name.
        """
        name = ""
        if self.static:
            name += ".static"
        if self.enabled_gptq:
            name += ".gptq"
        # range calibration only affects the name when it searches or rescales
        if self.enabled_calib_range and (self.calib_range.needs_search or self.calib_range.ratio != 1):
            name += ".range"
        # drop the leading "." separator
        return name[1:] if name else ""
@configclass
@dataclass
class LlmWeightQuantizerConfig(LlmQuantizerConfig):
    """Llm Weight Quantizer Configuration.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data type for per-group quantization.
        intermediate_dtypes (`Sequence[QuantDataType]`, *optional*, defaults to `()`):
            The intermediate quantization data types.
        intermediate_levels (Sequence[int], *optional*, defaults to `()`):
            The intermediate quantization levels.
        needs_dequant_saturation (`bool`, *optional*, defaults to `False`):
            Whether the dequantization needs saturation.
        skips (`Sequence[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        kernel_gptq (`QuantGptqConfig` or `None`, *optional*, defaults to `None`):
            The GPTQ kernel configuration.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The dynamic range calibration configuration.
    """

    # weights are always quantized with static (offline) scales, so `static`
    # is fixed to True and removed from the constructor signature
    static: bool = field(init=False, default=True)
@configclass
@dataclass
class LlmActivationQuantizerConfig(LlmQuantizerConfig):
    """Llm Activation quantization configuration.

    Args:
        dtype (`QuantDataType` or `None`, *optional*, defaults to `None`):
            The quantization data type.
        zero_point (`ZeroPointDomain` or `None`, *optional*, defaults to `None`):
            The zero-point domain.
        group_shapes (`Sequence[Sequence[int]]`, *optional*, defaults to `((-1, -1, -1),)`):
            The shapes for per-group quantization.
        scale_dtypes (`Sequence[torch.dtype | QuantDataType | None]`, *optional*, defaults to `(None,)`):
            The quantization scale data type for per-group quantization.
        skips (`Sequence[str]`, *optional*, defaults to `[]`):
            The keys of the modules to skip.
        static (`bool`, *optional*, defaults to `False`):
            Whether to use static quantization.
        calib_range (`SkipBasedDynamicRangeCalibConfig` or `None`, *optional*, defaults to `None`):
            The dynamic range calibration configuration.
    """

    # progressive (multi-step) quantization does not apply to activations:
    # intermediate dtypes/levels and dequant saturation are fixed off
    intermediate_dtypes: tp.Sequence[QuantDataType] = field(init=False, default=())
    intermediate_levels: tp.Sequence[int] = field(init=False, default=())
    needs_dequant_saturation: bool = field(init=False, default=False)
    # GPTQ is a weight-only kernel; always disabled for activations
    kernel_gptq: None = field(init=False, default=None)
@configclass
@dataclass
class LlmModuleQuantizerConfig(EnableConfig):
    """Llm Module quantization configuration.

    Args:
        wgts (`LlmWeightQuantizerConfig`):
            The weight quantization configuration.
        ipts (`LlmActivationQuantizerConfig`):
            The input activation quantization configuration.
        opts (`LlmActivationQuantizerConfig`):
            The output activation quantization configuration.
    """

    wgts: LlmWeightQuantizerConfig
    ipts: LlmActivationQuantizerConfig
    opts: LlmActivationQuantizerConfig

    def is_enabled(self) -> bool:
        """Whether the quantization is enabled."""
        # enabled if any of weights, inputs, or outputs is quantized
        return self.enabled_wgts or self.enabled_ipts or self.enabled_opts

    @property
    def enabled_wgts(self) -> bool:
        """Whether to enable weight quantization."""
        return self.wgts is not None and self.wgts.is_enabled()

    @property
    def enabled_ipts(self) -> bool:
        """Whether to enable input activation quantization."""
        return self.ipts is not None and self.ipts.is_enabled()

    @property
    def enabled_opts(self) -> bool:
        """Whether to enable output activation quantization."""
        return self.opts is not None and self.opts.is_enabled()

    def generate_dirnames(
        self,
        *,
        prefix: str = "",
        shape: torch.Size | tuple[int, ...] = (1024, 1024, 16, 16),
        default_dtype: torch.dtype = torch.float16,
        **kwargs,
    ) -> list[str]:
        """Get the directory names of the quantization configuration.

        Args:
            prefix (`str`, *optional*, defaults to `""`):
                The prefix for the directory names.
            shape (`torch.Size` or `tuple[int, ...]`, *optional*, defaults to `(1024, 1024, 16, 16)`):
                The shape of the tensor to be quantized.
            default_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
                The dtype of the tensor to be quantized.

        Returns:
            `list[str]`:
                The directory names of the quantization configuration.
                    - The number of effective bits.
                    - The name of the quantization data type.
                    - The name of the group shapes.
                    - The name of the modules to skip.
        """
        # the three sub-configs produce parallel name lists ("w"/"x"/"y"
        # prefixed) which are zipped into combined per-level names
        wgts_names = self.wgts.generate_dirnames(prefix="w", shape=shape, default_dtype=default_dtype)
        ipts_names = self.ipts.generate_dirnames(prefix="x", shape=shape, default_dtype=default_dtype)
        opts_names = self.opts.generate_dirnames(prefix="y", shape=shape, default_dtype=default_dtype)
        names = [
            f"{wgts_name}-{ipts_name}-{opts_name}"
            for wgts_name, ipts_name, opts_name in zip(wgts_names, ipts_names, opts_names, strict=True)
        ]
        if prefix:
            names = [f"{prefix}.[{name}]" for name in names]
        return names

    def generate_calib_dirname(self) -> str:
        """Generate the name for quantization calibration.

        Returns:
            `str`:
                The name.
        """
        name = ""
        if self.enabled_wgts:
            calib_name = self.wgts.generate_calib_dirname()
            if calib_name:
                name += f"-w.{calib_name}"
        if self.enabled_ipts:
            calib_name = self.ipts.generate_calib_dirname()
            if calib_name:
                name += f"-x.{calib_name}"
        if self.enabled_opts:
            calib_name = self.opts.generate_calib_dirname()
            if calib_name:
                name += f"-y.{calib_name}"
        # drop the leading "-" separator
        return name[1:] if name else name
================================================
FILE: deepcompressor/app/llm/quant/quantizer/quantizer.py
================================================
# -*- coding: utf-8 -*-
"""Tensor Quantizer module."""
import typing as tp
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from deepcompressor.calib.range import calibrate_dynamic_range
from deepcompressor.data.cache import TensorsCache
from deepcompressor.data.common import TensorType
from deepcompressor.data.range import DynamicRange
from deepcompressor.quantizer.kernel import QuantGptqConfig
from deepcompressor.quantizer.processor import Quantizer
from .config import LlmActivationQuantizerConfig, LlmQuantizerConfig, LlmWeightQuantizerConfig
__all__ = ["LlmQuantizer", "LlmWeightQuantizer", "LlmActivationQuantizer"]
@dataclass
class LlmQuantizer(Quantizer):
    """Llm quantizer class.

    Args:
        config (`LlmQuantizerConfig`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        tensor_type (`TensorType`, *optional*, defaults to `TensorType.Weights`):
            The type of the tensor to quantize.
        channels_dim (`int` or `None`, *optional*, defaults to `None`):
            The dimension of channels.
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        low_rank (`QuantLowRankConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch configuration.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: LlmQuantizerConfig
    # GPTQ kernel config mirrored from `config`; populated in __post_init__
    kernel: QuantGptqConfig | None = field(init=False)
    tensor_type: TensorType = TensorType.Weights

    def __post_init__(self) -> None:
        # NOTE(review): does not chain to Quantizer.__post_init__ — confirm the
        # base class defines none (subclasses do call super().__post_init__()).
        self.kernel = self.config.kernel_gptq

    def calibrate_dynamic_range(
        self,
        modules: tp.Sequence[nn.Module],
        activations: TensorsCache,
        weights: tp.Sequence[nn.Parameter] | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_weights: tp.Sequence[tuple[nn.Parameter, torch.Tensor]] | None = None,
        orig_activations: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
    ) -> tp.Sequence[DynamicRange] | None:
        """Calibrate the dynamic range.

        Args:
            modules (`Sequence[nn.Module]`):
                The modules to calibrate.
            activations (`TensorsCache`):
                The inputs cache if the tensor type is not outputs, or the outputs cache if the tensor type is outputs.
            weights (`Sequence[nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights to calibrate.
                If not provided, the weights of the modules will be used.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The cache of the inputs for evaluation.
                If not provided, the `activations` cache will be used.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module to evaluate the quantization error.
                If not provided, the module to calibrate will be used.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation.
            orig_weights (`Sequence[tuple[nn.Parameter, torch.Tensor]]` or `None`, *optional*, defaults to `None`):
                The original weights.
            orig_activations (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original activations.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original evaluation inputs.

        Returns:
            `Sequence[DynamicRange]` or `None`:
                The dynamic ranges of each quantization step.
        """
        if (
            not self.is_enabled()
            or self.config.calib_range is None
            or not self.config.calib_range.is_enabled_for(self.key)
        ):
            # no calibration configured (or this key is skipped): clear any
            # previously stored range
            self.dynamic_range = None
        else:
            self.dynamic_range = calibrate_dynamic_range(
                tensor_type=self.tensor_type,
                config=self.config.calib_range,
                static=self.config.static,
                quantizer=self,
                modules=modules,
                activations=activations,
                weights=weights,
                eval_inputs=eval_inputs,
                eval_module=eval_module,
                eval_kwargs=eval_kwargs,
                orig_weights=orig_weights,
                orig_activations=orig_activations,
                orig_eval_inputs=orig_eval_inputs,
            )
        return self.dynamic_range
@dataclass
class LlmWeightQuantizer(LlmQuantizer):
    """Weight quantizer for large language models.

    A thin specialization of `LlmQuantizer` that is fixed to the weight tensor
    type and exposes a single-module calibration entry point.

    Args:
        config (`LlmWeightQuantizerConfig`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        low_rank (`QuantLowRankConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch configuration.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: LlmWeightQuantizerConfig
    channels_dim: None = field(init=False, default=None)
    tensor_type: TensorType = field(init=False, default=TensorType.Weights)

    def calibrate_dynamic_range(
        self,
        module: nn.Module,
        inputs: TensorsCache,
        weight: nn.Parameter | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_inputs: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
    ) -> DynamicRange | tuple[DynamicRange, ...]:
        """Calibrate the dynamic range for a single module's weight.

        Args:
            module (`nn.Module`):
                The module to calibrate.
            inputs (`TensorsCache`):
                The inputs cache.
            weight (`nn.Parameter` or `None`, *optional*, defaults to `None`):
                The weight parameter to calibrate; falls back to `module.weight`.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The cache of the inputs for evaluation; falls back to `inputs`.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module used to evaluate the quantization error; falls back to `module`.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation.
            orig_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original inputs.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original evaluation inputs.

        Returns:
            `DynamicRange` or `tuple[DynamicRange, ...]`:
                The dynamic ranges of each quantization step.
        """
        # default to the module's own weight when none is supplied explicitly
        target_weight = module.weight if weight is None else weight
        return super().calibrate_dynamic_range(
            modules=[module],
            weights=[target_weight],
            activations=inputs,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            orig_activations=orig_inputs,
            orig_eval_inputs=orig_eval_inputs,
        )
@dataclass
class LlmActivationQuantizer(LlmQuantizer):
    """Llm Activation Quantizer class.

    Args:
        config (`LlmActivationQuantizerConfig`):
            The quantizer configuration.
        key (`str`, *optional*, defaults to `""`):
            The key of the quantizer.
        tensor_type (`TensorType`, *optional*, defaults to `TensorType.Inputs`):
            The type of the tensor to quantize.
        channels_dim (`int` or `None`, *optional*, defaults to `None`):
            The dimension of channels.
        scale (`torch.Tensor` or `Sequence[torch.Tensor]` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        dynamic_range (`DynamicRange` or `Sequence[DynamicRange]` or `None`, *optional*, defaults to `None`):
            The dynamic range.
        range_bound (`RangeBound` or `None`, *optional*, defaults to `None`):
            The dynamic range bound.
        quant_range (`QuantRange` or `None`, *optional*, defaults to `None`):
            The quantization range.
        default_dtype (`torch.dtype` or `None`, *optional*, defaults to `None`):
            The default scale dtype
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The quantization development dtype.
        low_rank (`QuantLowRankConfig` or `None`, *optional*, defaults to `None`):
            The quantization low-rank branch configuration.
        input_packager (`BaseInputPackager` or `None`, *optional*, defaults to `None`):
            The input packager, used for unpacking and repacking the input tensor(s).
        output_packager (`BaseOutputPackager` or `None`, *optional*, defaults to `None`):
            The output packager, used for unpacking and repacking the output tensor(s).
    """

    config: LlmActivationQuantizerConfig
    # defaults to quantizing module inputs; may be outputs, but never weights
    tensor_type: TensorType = TensorType.Inputs

    def __post_init__(self) -> None:
        super().__post_init__()
        # activation quantizers operate on inputs/outputs only...
        assert self.tensor_type != TensorType.Weights, "The tensor type cannot be weights."
        # ...and must know which dimension indexes channels
        assert isinstance(self.channels_dim, int), "The channels dimension must be provided."
================================================
FILE: deepcompressor/app/llm/quant/reorder.py
================================================
# -*- coding: utf-8 -*-
"""LLM quantization channel reordering module."""
import gc
import typing as tp
import torch
import torch.nn as nn
import torch.utils
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from deepcompressor.calib.reorder import ChannelOrderCalibrator, ChannelReorderer
from deepcompressor.data.cache import IOTensorsCache, TensorCache, TensorsCache
from deepcompressor.quantizer.processor import Quantizer
from deepcompressor.utils import tools
from ..nn import LlmModelStruct, LlmTransformerBlockStruct
from .config import LlmQuantConfig
from .utils import get_needs_inputs_fn
__all__ = ["reorder_llm"]
def _extend_params_(
params: list[tuple[nn.Parameter, int]],
modules: list[nn.Linear | nn.Embedding, nn.LayerNorm],
out_channels_dim: int | None = None,
in_channels_dim: int | None = None,
) -> list[tuple[nn.Parameter, int]]:
"""Extend the parameters to be reordered."""
if out_channels_dim is not None:
assert in_channels_dim is None
else:
assert in_channels_dim is not None
for module in modules:
if module is None:
continue
if out_channels_dim is not None:
params.append((module.weight, out_channels_dim))
if hasattr(module, "bias") and module.bias is not None:
params.append((module.bias, 0))
else:
params.append((module.weight, in_channels_dim))
return params
@torch.inference_mode()
def reorder_llm_layer(  # noqa: C901
    layer: LlmTransformerBlockStruct,
    config: LlmQuantConfig,
    reorder_cache: dict[str, torch.Tensor],
    residual_calibrator: ChannelOrderCalibrator | None = None,
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
) -> ChannelOrderCalibrator | None:
    """Calibrate the channel order in a layer.

    Reorders (or records metrics for) the attention and feed-forward projections of
    one transformer block, mutating weights in place and caching the computed index
    tensors in ``reorder_cache`` (keyed by module/struct name, stored on CPU).

    Args:
        layer (`LlmTransformerBlockStruct`):
            Large language model layer to be reordered.
        config (`LlmQuantConfig`):
            Quantization config.
        reorder_cache (`dict[str, torch.Tensor]`):
            Reorder indexes cache.
        residual_calibrator (`ChannelOrderCalibrator` or `None`, *optional*, defaults to `None`):
            Channel order calibrator for residual modules.
        layer_cache (`dict[str, IOTensorsCache]`, *optional*, defaults to `None`):
            Layer activations cache.
        layer_kwargs (`dict[str, tp.Any]`, *optional*, defaults to `None`):
            Layer keyword arguments.

    Returns:
        `ChannelOrderCalibrator` or `None`:
            Channel order calibrator for residual modules.
    """
    logger = tools.logging.getLogger(f"{__name__}.Reorder")
    layer_cache = layer_cache or {}
    attn = layer.attn_struct
    qkv_proj, out_proj = attn.qkv_proj, attn.out_proj
    num_heads, num_head_repeats = attn.config.num_query_heads, attn.config.num_head_repeats
    # region reorder in attention module
    # QKV in-channels are only reordered here in dynamic mode; the static ordering
    # for these channels is handled globally by the residual calibrator below.
    if config.reorder.dynamic and config.reorder.is_enabled_for(attn.qkv_proj_key):
        logger.debug("- Reordering %s", attn.qkv_proj_names)
        cache_key = attn.name
        if cache_key not in reorder_cache:
            # Calibrate one shared input-channel order for the q/k/v projections.
            index = ChannelOrderCalibrator(
                config=config.reorder,
                weight_quantizer=Quantizer(config.wgts, key=attn.qkv_proj_key),
                input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=attn.qkv_proj_key),
                develop_dtype=config.develop_dtype,
            ).calibrate(
                x_wgts=[m.weight for m in qkv_proj],
                # q/k/v share the same input, so v_proj's cached inputs stand in for all.
                x_acts=layer_cache[attn.v_proj_name].inputs if layer_cache else None,
                x_mods=qkv_proj,
                eval_inputs=layer_cache[cache_key].inputs if layer_cache else None,
                eval_module=attn.module,
                eval_kwargs=attn.filter_kwargs(layer_kwargs),
                reorder_wgts=[(m.weight, 1) for m in qkv_proj],
                reorder_ipt_mods=[(attn.module, -1, None)],
                reorder_opt_mods=[],
            )
            # Store on CPU so cached indexes are device-independent and reusable.
            reorder_cache[cache_key] = index.to(device=torch.device("cpu"))
        index = reorder_cache[cache_key]
        for proj in qkv_proj:
            index = index.to(proj.weight.device)
            proj.weight.data = proj.weight.data.index_select(1, index)
        # Permute the attention inputs at runtime to match the permuted in-channels.
        ChannelReorderer(index, channels_dim=-1).as_hook().register(attn.module)
        gc.collect()
        torch.cuda.empty_cache()
    if config.reorder.is_enabled_for(attn.out_proj_key):
        logger.debug("- Reordering %s", attn.out_proj_name)
        cache_key = attn.out_proj_name
        if cache_key not in reorder_cache:
            index = ChannelOrderCalibrator(
                config=config.reorder,
                weight_quantizer=Quantizer(config.wgts, key=attn.out_proj_key),
                input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=attn.out_proj_key),
                # Head-aware calibration: the order must stay within head boundaries.
                num_heads=num_heads,
                num_head_repeats=num_head_repeats,
                develop_dtype=config.develop_dtype,
            ).calibrate(
                x_wgts=[out_proj.weight],
                x_acts=layer_cache[cache_key].inputs if layer_cache else None,
                x_mods=[out_proj],
                eval_inputs=layer_cache[cache_key].inputs if layer_cache else None,
                eval_module=out_proj,
                reorder_wgts=[(out_proj.weight, 1)],
                reorder_ipt_mods=[(out_proj, -1, None)],
                reorder_opt_mods=[],
            )
            reorder_cache[cache_key] = index.to(device=torch.device("cpu"))
        index = reorder_cache[cache_key]
        index = index.to(out_proj.weight.device)
        # Permute out_proj input channels, then compensate by permuting the
        # output channels (and bias) of the v projection that feeds it.
        out_proj.weight.data = out_proj.weight.data.index_select(1, index)
        v_proj = qkv_proj[2]
        if num_heads > 1 and num_head_repeats > 1:
            # Grouped-query attention: v has fewer heads than the query side.
            # Convert the per-query-head order into an order over v's heads by
            # keeping one representative per repeat group.
            num_channels = index.numel()
            head_channels = num_channels // num_heads
            index = index.view(num_heads, head_channels)
            delta = torch.arange(0, num_channels, head_channels, device=index.device).view(num_heads, 1)
            index = index - delta  # make indexes relative within each head
            num_v_channels = num_channels // num_head_repeats
            num_v_heads = num_heads // num_head_repeats
            # Repeated heads share one v head, so any representative works; take the first.
            index = index.view(num_v_heads, num_head_repeats, head_channels)[:, 0, :]
            delta = torch.arange(0, num_v_channels, head_channels, device=index.device).view(num_v_heads, 1)
            index = index + delta  # back to absolute channel ids in v's space
            index = index.view(-1)
        v_proj.weight.data = v_proj.weight.data.index_select(0, index.to(v_proj.weight.device))
        if v_proj.bias is not None:
            v_proj.bias.data = v_proj.bias.data[index.to(v_proj.bias.device)].contiguous()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    ffn = layer.ffn_struct
    num_experts = ffn.config.num_experts
    up_proj, down_proj = ffn.up_projs, ffn.down_projs
    # region reorder in feed-forward module
    # Like QKV, the up-projection in-channels belong to the residual stream and are
    # only reordered per-layer in dynamic mode.
    if config.reorder.dynamic and config.reorder.is_enabled_for(ffn.up_proj_key):
        logger.debug("- Reordering %s", ffn.name)
        cache_key = ffn.name
        if cache_key not in reorder_cache:
            index = ChannelOrderCalibrator(
                config=config.reorder,
                weight_quantizer=Quantizer(config.wgts, key=ffn.up_proj_key),
                input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=ffn.up_proj_key),
                develop_dtype=config.develop_dtype,
            ).calibrate(
                x_wgts=[m.weight for m in up_proj],
                x_acts=layer_cache[cache_key].inputs if layer_cache else None,
                x_mods=up_proj,
                eval_inputs=layer_cache[cache_key].inputs if layer_cache else None,
                eval_module=ffn.module,
                reorder_wgts=[(m.weight, 1) for m in up_proj],
                reorder_ipt_mods=[(ffn.module, -1, None)],
                reorder_opt_mods=[],
            )
            reorder_cache[cache_key] = index.to(device=torch.device("cpu"))
        index = reorder_cache[cache_key]
        index = index.to(device=up_proj[0].weight.device)
        for fc in up_proj:
            fc.weight.data = fc.weight.data.index_select(1, index.to(fc.weight.device))
        # The MoE router consumes the same (reordered) hidden states.
        moe_gate = ffn.moe_gate
        if moe_gate is not None:
            moe_gate.weight.data = moe_gate.weight.data.index_select(1, index.to(moe_gate.weight.device))
        ChannelReorderer(index, channels_dim=-1).as_hook().register(ffn.module)
    if config.reorder.is_enabled_for(ffn.down_proj_key):
        # Each expert's down projection gets its own channel order.
        for expert_idx, (fc2_name, fc2) in enumerate(zip(ffn.down_proj_names, down_proj, strict=True)):
            logger.debug("- Reordering module %s", fc2_name)
            cache_key = fc2_name
            if cache_key not in reorder_cache:
                index = ChannelOrderCalibrator(
                    config=config.reorder,
                    weight_quantizer=Quantizer(config.wgts, key=ffn.down_proj_key),
                    input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=ffn.down_proj_key),
                    develop_dtype=config.develop_dtype,
                ).calibrate(
                    x_wgts=[fc2.weight],
                    x_acts=layer_cache[cache_key].inputs if layer_cache else None,
                    x_mods=[fc2],
                    eval_inputs=layer_cache[cache_key].inputs if layer_cache else None,
                    eval_module=fc2,
                    reorder_wgts=[(fc2.weight, 1)],
                    reorder_ipt_mods=[(fc2, -1, None)],
                    reorder_opt_mods=[],
                )
                reorder_cache[cache_key] = index.to(device=torch.device("cpu"))
            index = reorder_cache[cache_key]
            index = index.to(fc2.weight.device)
            fc2.weight.data = fc2.weight.data.index_select(1, index.to(fc2.weight.device))
            # Compensate by permuting the out-channels (and bias) of the expert's
            # up projections; up_projs are interleaved across experts.
            for fc1 in up_proj[expert_idx::num_experts]:
                fc1.weight.data = fc1.weight.data.index_select(0, index.to(fc1.weight.device))
                if fc1.bias is not None:
                    fc1.bias.data = fc1.bias.data[index.to(fc1.bias.device)].contiguous()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # The residual calibrator only applies to static reordering with both
    # residual-consuming projections enabled; otherwise drop it.
    if residual_calibrator is not None and (
        config.reorder.dynamic
        or not config.reorder.is_enabled_for(attn.qkv_proj_key)
        or not config.reorder.is_enabled_for(ffn.up_proj_key)
    ):
        residual_calibrator = None
    if residual_calibrator is not None and "residual" not in reorder_cache:
        # Accumulate channel metrics over all residual-stream consumers of this layer.
        residual_calibrator.update_channel_metrics(
            weights=[m.weight for m in qkv_proj],
            inputs=layer_cache[attn.v_proj_name].inputs if layer_cache else None,
        )
        for expert_idx in range(num_experts):
            residual_calibrator.update_channel_metrics(
                weights=[m.weight for m in up_proj[expert_idx::num_experts]],
                inputs=layer_cache[ffn.up_proj_names[expert_idx]].inputs if layer_cache else None,
            )
    return residual_calibrator
@torch.inference_mode()
def reorder_llm(  # noqa: C901
    model: nn.Module | LlmModelStruct,
    config: LlmQuantConfig,
    tokenizer: PreTrainedTokenizer | None = None,
    reorder_cache: dict[str, torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Reorder the channels of the large language model for quantization.

    Runs per-layer channel-order calibration and, for static reordering, a global
    calibration of the residual-stream channel order that is then applied to every
    module reading from or writing to the residual stream.

    Args:
        model (`nn.Module` or `LlmStruct`):
            Model to be reordered.
        config (`LlmQuantConfig`):
            Quantization config.
        tokenizer (`PreTrainedTokenizer` or `None`, *optional*, defaults to `None`):
            Tokenizer.
        reorder_cache (`dict[str, torch.Tensor]`, *optional*, defaults to `None`):
            Reorder indexes cache.

    Returns:
        `dict[str, torch.Tensor]`:
            Reorder indexes cache.
    """
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    logger = tools.logging.getLogger(f"{__name__}.Reorder")
    reorder_cache = {} if reorder_cache is None else reorder_cache
    residual_calibrator = None
    # The residual calibrator is only created for static (non-dynamic) reordering
    # when the residual stream is enabled and not already calibrated.
    if "residual" not in reorder_cache and not config.reorder.dynamic and config.reorder.is_enabled_for("residual"):
        residual_calibrator = ChannelOrderCalibrator(
            config=config.reorder,
            weight_quantizer=Quantizer(config.wgts),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1),
            develop_dtype=config.develop_dtype,
        )
    with tools.logging.redirect_tqdm():
        if not reorder_cache:
            # No cached indexes yet: calibrate layer by layer with activation data.
            calib_cache = config.calib.build_loader(tokenizer)
            for _, (layer, layer_cache, layer_kwargs) in tqdm(
                calib_cache.iter_layer_activations(
                    model,
                    needs_inputs_fn=get_needs_inputs_fn(model=model, config=config),
                ),
                desc="reordering",
                leave=False,
                total=len(model.backbone_struct.layer_structs),
                dynamic_ncols=True,
            ):
                residual_calibrator = reorder_llm_layer(
                    layer=layer,
                    config=config,
                    reorder_cache=reorder_cache,
                    residual_calibrator=residual_calibrator,
                    layer_cache=layer_cache,
                    layer_kwargs=layer_kwargs,
                )
                gc.collect()
                torch.cuda.empty_cache()
        else:
            # Cached indexes exist: re-apply them without calibration data.
            calib_cache = None
            for layer in tqdm(
                model.backbone_struct.layer_structs,
                desc="reordering",
                leave=False,
                dynamic_ncols=True,
            ):
                residual_calibrator = reorder_llm_layer(
                    layer=layer,
                    config=config,
                    reorder_cache=reorder_cache,
                    residual_calibrator=residual_calibrator,
                )
    if residual_calibrator is None:
        # Dynamic reordering (or residual disabled): per-layer work is all there is.
        return reorder_cache
    # region add extra params to be reordered
    # Collect every parameter touching the residual stream: modules reading from it
    # are reordered along input channels, modules writing to it along output channels.
    backbone = model.backbone_struct
    x_mods: list[nn.Linear] = []
    reorder_wgts: list[tuple[nn.Parameter, int]] = []
    for layer in backbone.layer_structs:
        x_mods.extend(layer.attn_struct.qkv_proj)
        x_mods.extend(layer.ffn_struct.up_projs)
        _extend_params_(
            reorder_wgts,
            [
                layer.pre_attn_norm,
                layer.attn_struct.out_proj,
                layer.post_attn_norm,
                layer.pre_ffn_norm,
                *layer.ffn_struct.down_projs,
                layer.post_ffn_norm,
            ],
            out_channels_dim=0,
        )
        # The MoE router reads the residual stream, so its in-channels move too.
        _extend_params_(reorder_wgts, [layer.ffn_struct.moe_gate], in_channels_dim=1)
    need_reorder_head = model.head is not None
    if backbone.proj_in is not None:
        _extend_params_(reorder_wgts, [backbone.proj_in], out_channels_dim=0)
        _extend_params_(reorder_wgts, [backbone.embed_positions], out_channels_dim=1)
    else:
        # Embedding weights carry channels along dim 1.
        _extend_params_(reorder_wgts, [backbone.embed_tokens, backbone.embed_positions], out_channels_dim=1)
    _extend_params_(reorder_wgts, [backbone.norm_in, backbone.norm_out], out_channels_dim=0)
    if backbone.proj_out is not None:
        # proj_out maps back out of the residual space, so the head stays untouched.
        _extend_params_(reorder_wgts, [backbone.proj_out], in_channels_dim=1)
        need_reorder_head = False
    logger.debug("- Reordering residual modules")
    _extend_params_(reorder_wgts, x_mods, in_channels_dim=1)
    if "residual" not in reorder_cache:
        calib_cache = calib_cache or config.calib.build_loader(tokenizer)
        residual_calibrator.init_channel_indexes()
        index = residual_calibrator.calibrate(
            x_wgts=[m.weight for m in x_mods],
            x_acts=None,
            eval_inputs=TensorsCache(TensorCache(calib_cache.dataset.data, channels_dim=-1, orig_device="cuda")),
            eval_module=model.backbone,
            x_mods=x_mods,
            reorder_wgts=reorder_wgts,
            reorder_ipt_mods=[],
            reorder_opt_mods=[(model.backbone, -1, None)] if need_reorder_head else [],
        )
        reorder_cache["residual"] = index.to(device=torch.device("cpu"))
        del x_mods, residual_calibrator, calib_cache
        gc.collect()
        torch.cuda.empty_cache()
    # Apply the residual order to every collected parameter along its dimension.
    index = reorder_cache["residual"]
    for wgt, dim in reorder_wgts:
        wgt.data = wgt.data.index_select(dim=dim, index=index.to(wgt.data.device))
    # Tied embeddings were already permuted via embed_tokens; skip the head then.
    if need_reorder_head and not model.config.tie_word_embeddings:
        model.head.weight.data = model.head.weight.data.index_select(dim=1, index=index.to(model.head.weight.device))
    gc.collect()
    torch.cuda.empty_cache()
    return reorder_cache
================================================
FILE: deepcompressor/app/llm/quant/rotate.py
================================================
# -*- coding: utf-8 -*-
"""Large Language Model Rotation module."""
import gc
import torch
from tqdm import tqdm
from transformers import PreTrainedModel
from deepcompressor.calib.config import QuantRotationConfig
from deepcompressor.calib.rotate import (
get_rotation_matrix,
hadamard_in_channels,
rotate_in_channels,
rotate_out_channels,
transform_norm_and_linear,
)
from deepcompressor.utils import tools
from ..nn import LlmModelStruct
__all__ = ["rotate_llm"]
@torch.inference_mode()
def rotate_llm(  # noqa: C901
    model: PreTrainedModel | LlmModelStruct,
    /,
    config: QuantRotationConfig,
    rotation: torch.Tensor | None = None,
) -> torch.Tensor:
    """Rotate the weights of the large language model.

    Fuses norms into neighboring linears, then applies an orthogonal rotation to
    the residual stream (and optionally per-head / down-projection transforms),
    mutating the model's weights in place.

    Args:
        model (`PreTrainedModel` or `LlmStruct`):
            Model to be rotated.
        config (`QuantRotationConfig`):
            Rotation configuration.
        rotation (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            Rotation matrix.

    Returns:
        `torch.Tensor`:
            The rotation matrix.
    """
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    # Record every linear's original device/dtype so placement can be restored at the end.
    devices: list[torch.device] = []
    dtypes: list[torch.dtype] = []
    linears: list[torch.nn.Linear] = []
    size: float = 0
    for m in model.module.modules():
        if isinstance(m, torch.nn.Linear):
            devices.append(m.weight.device)
            dtypes.append(m.weight.dtype)
            linears.append(m)
            size += m.weight.numel() / 1e9
    for linear in linears:
        # Rotate in float32 for accuracy; offload to CPU when the total linear
        # weight count exceeds ~30B parameters to avoid GPU OOM.
        linear.to(dtype=torch.float32, device="cpu" if size > 30 else None)
    for block in model.iter_transformer_block_structs():
        assert not block.post_attn_norms, "Rotation is only supported for models without post-attention norms."
        assert not block.post_ffn_norm, "Rotation is only supported for models without post-FFN norms."
    logger = tools.logging.getLogger(f"{__name__}.Rotate")
    backbone = model.backbone_struct
    layers = backbone.layer_structs
    # region transform norm and linear
    # Determine which modules produce the hidden states consumed by the first
    # pre-attention norm (embeddings and/or input projection), with the dimension
    # along which their output channels live.
    if backbone.norm_in is None:
        if backbone.proj_in is None:
            prev_modules = [backbone.embed_tokens]
            prev_out_channels_dims = 1
            if backbone.embed_positions is not None:
                prev_modules.append(backbone.embed_positions)
        elif backbone.embed_positions is None:
            prev_modules = [backbone.proj_in]
            prev_out_channels_dims = 0
        else:
            prev_modules = [backbone.proj_in, backbone.embed_positions]
            prev_out_channels_dims = [0, 1]
    else:
        prev_modules = [backbone.norm_in]
        prev_out_channels_dims = 0
    with tools.logging.redirect_tqdm():
        for layer in tqdm(layers, desc="Transforming norm and linear", dynamic_ncols=True):
            logger.debug(f"- Transforming norm and linear in {layer.name}")
            # Fold the pre-attention norm's affine parameters into its neighbors.
            transform_norm_and_linear(
                parent=layer.module,
                norm_name=layer.pre_attn_norm_rname,
                next_modules=layer.attn_struct.qkv_proj,
                prev_modules=prev_modules,
                prev_out_channels_dims=prev_out_channels_dims,
            )
            prev_modules = [layer.attn_struct.out_proj]
            prev_out_channels_dims = 0
            # Fold the pre-FFN norm into the up projections (and MoE router, if any).
            transform_norm_and_linear(
                parent=layer.module,
                norm_name=layer.pre_ffn_norm_rname,
                next_modules=layer.ffn_struct.up_projs
                + ([layer.ffn_struct.moe_gate] if layer.ffn_struct.moe_gate is not None else []),
                prev_modules=prev_modules,
                prev_out_channels_dims=prev_out_channels_dims,
            )
            prev_modules = layer.ffn_struct.down_projs
            prev_out_channels_dims = 0
            gc.collect()
            torch.cuda.empty_cache()
    logger.debug(f"- Transforming {backbone.norm_out_name}")
    transform_norm_and_linear(
        parent=backbone.module,
        norm_name=backbone.norm_out_rname,
        next_modules=[model.head if backbone.proj_out is None else backbone.proj_out],
        prev_modules=prev_modules,
        prev_out_channels_dims=prev_out_channels_dims,
    )
    # endregion
    if rotation is None:
        rotation = get_rotation_matrix(backbone.config.num_channels, random=config.random)
    # region rotate embeddings
    if backbone.proj_in is None:
        logger.debug(f"- Rotating {backbone.embed_tokens_name}")
        weight = backbone.embed_tokens.weight
        rotation = rotation.to(weight.device)
        rotate_in_channels(weight, rotation=rotation)
    else:
        logger.debug(f"- Rotating {backbone.proj_in_name} (out)")
        weight = backbone.proj_in.weight
        rotation = rotation.to(weight.device)
        rotate_out_channels(weight, rotation=rotation, bias=backbone.proj_in.bias)
    if backbone.embed_positions is not None:
        logger.debug(f"- Rotating {backbone.embed_positions_name}")
        weight = backbone.embed_positions.weight
        rotation = rotation.to(weight.device)
        rotate_in_channels(weight, rotation=rotation)
    # endregion
    down_proj = []
    # region rotate backbone layers
    # Separate per-head rotation used for the v/o projection pair when enabled.
    head_rotation = get_rotation_matrix(model.config.num_head_channels, random=config.random)
    with tools.logging.redirect_tqdm():
        for layer in tqdm(layers, desc="Rotating backbone layers", dynamic_ncols=True):
            logger.debug(f"- Rotating {layer.name}")
            tools.logging.Formatter.indent_inc()
            attn, ffn = layer.attn_struct, layer.ffn_struct
            # Undo the residual rotation at every consumer (in-channels) and
            # re-apply it at every producer (out-channels).
            for proj_name, proj in zip(attn.qkv_proj_names, attn.qkv_proj, strict=True):
                logger.debug(f"- Rotating {proj_name} (in)")
                rotation = rotation.to(proj.weight.device)
                rotate_in_channels(proj.weight, rotation=rotation)
            logger.debug(f"- Rotating {attn.out_proj_name} (out)")
            rotation = rotation.to(attn.out_proj.weight.device)
            rotate_out_channels(attn.out_proj.weight, rotation=rotation, bias=attn.out_proj.bias)
            if attn.out_proj_key in config.transforms:
                # Extra head-wise rotation between v and o projections.
                logger.debug(f"- Rotating {attn.v_proj_name} (out)")
                rotate_out_channels(attn.v_proj.weight, rotation=head_rotation, bias=attn.v_proj.bias)
                logger.debug(f"- Rotating {attn.o_proj_name} (in)")
                rotate_in_channels(attn.o_proj.weight, rotation=head_rotation)
            for fc_name, fc in zip(ffn.up_proj_names, ffn.up_projs, strict=True):
                logger.debug(f"- Rotating {fc_name} (in)")
                rotation = rotation.to(fc.weight.device)
                rotate_in_channels(fc.weight, rotation=rotation)
            if ffn.moe_gate is not None:
                logger.debug(f"- Rotating {ffn.moe_gate_name} (in)")
                rotation = rotation.to(ffn.moe_gate.weight.device)
                rotate_in_channels(ffn.moe_gate.weight, rotation=rotation)
            for fc_name, fc in zip(ffn.down_proj_names, ffn.down_projs, strict=True):
                logger.debug(f"- Rotating {fc_name} (out)")
                rotation = rotation.to(fc.weight.device)
                rotate_out_channels(fc.weight, rotation=rotation, bias=fc.bias)
            # Collect down projections for the optional Hadamard transform below.
            if ffn.down_proj_key in config.transforms:
                down_proj.extend(ffn.down_projs)
            tools.logging.Formatter.indent_dec()
            gc.collect()
            torch.cuda.empty_cache()
    if backbone.proj_out is not None:
        logger.debug(f"- Rotating {backbone.proj_out_name} (in)")
        weight = backbone.proj_out.weight
        rotation = rotation.to(weight.device)
        rotate_in_channels(weight, rotation=rotation)
        logger.debug(f"- Rotating {backbone.proj_out_name} (out)")
        rotation = rotation.to(weight.device)
        rotate_out_channels(weight, rotation=rotation, bias=backbone.proj_out.bias)
    # endregion
    if down_proj:
        logger.debug(f"- Applying Hadamard transform on {backbone.name}.down_proj (in)")
        hadamard_in_channels(down_proj)
    # Rotate the input of the module that consumes the final hidden states.
    # NOTE(review): when proj_out exists its in-channels appear to be rotated both
    # here and in the block above — confirm this double application is intended.
    if backbone.proj_out is not None:
        logger.debug(f"- Rotating {backbone.proj_out_name} (in)")
        weight = backbone.proj_out.weight
    else:
        logger.debug(f"- Rotating {model.head_name} (in)")
        weight = model.head.weight
    rotation = rotation.to(weight.device)
    rotate_in_channels(weight, rotation=rotation)
    # Restore every linear to its original device and dtype.
    for device, dtype, linear in zip(devices, dtypes, linears, strict=True):
        linear.to(device=device, dtype=dtype)
    return rotation.cpu()
================================================
FILE: deepcompressor/app/llm/quant/smooth.py
================================================
# -*- coding: utf-8 -*-
"""LLM smooth quantization module."""
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from deepcompressor.calib.smooth import ActivationSmoother, smooth_attention, smooth_linear_modules
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.quantizer.processor import Quantizer
from deepcompressor.utils import tools
from ..nn.struct import LlmModelStruct, LlmTransformerBlockStruct
from .config import LlmQuantConfig
from .utils import get_needs_inputs_fn, get_needs_outputs_fn
__all__ = ["smooth_llm"]
@torch.inference_mode()
def smooth_llm_layer(  # noqa: C901
    layer: LlmTransformerBlockStruct,
    config: LlmQuantConfig,
    smooth_cache: dict[str, torch.Tensor],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
) -> None:
    """Smooth a large language model layer.

    Computes (or re-applies cached) smoothing scales for the attention q/k pair and
    each linear projection of the block, mutating weights in place and storing the
    scales in ``smooth_cache`` keyed by module name.

    Args:
        layer (`LlmTransformerBlockStruct`):
            Large language model layer to smooth.
        config (`LlmQuantConfig`):
            Quantization configuration.
        smooth_cache (`dict[str, torch.Tensor]`):
            Smoothing scale caches.
        layer_cache (`dict[str, IOTensorsCache]` or `None`, *optional*, defaults to `None`):
            Activation caches of the layer.
        layer_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for the layer.
    """
    logger = tools.logging.getLogger(f"{__name__}.SmoothQuant")
    logger.debug("- Smoothing %s", layer.name)
    tools.logging.Formatter.indent_inc()
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    attn, ffn = layer.attn_struct, layer.ffn_struct
    # region attention qk
    # Only smooth q/k when output (q/k state) quantization is enabled for either.
    needs_quant = config.enabled_opts
    needs_quant = needs_quant and (config.opts.is_enabled_for(attn.q_key) or config.opts.is_enabled_for(attn.k_key))
    if config.smooth.enabled_attn and needs_quant:
        logger.debug("- %s.%s", attn.name, attn.k_rkey)
        cache_key = f"{attn.name}.{attn.k_rkey}"
        # Balance query/key magnitudes; a cached scale is re-applied as-is.
        smooth_cache[cache_key] = smooth_attention(
            k_proj=attn.k_proj,
            q_proj=attn.q_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.attn,
            query_quantizer=Quantizer(config.opts, channels_dim=-1, key=attn.q_key),
            key_quantizer=Quantizer(config.opts, channels_dim=-1, key=attn.k_key),
            queries=layer_cache[attn.q_name].outputs if layer_cache else None,
            keys=layer_cache[attn.k_name].outputs if layer_cache else None,
            attn_q=attn.q,
            attn_k=attn.k,
            eval_inputs=layer_cache[attn.name].inputs if layer_cache else None,
            eval_module=attn,
            eval_kwargs=attn.filter_kwargs(layer_kwargs),
            num_heads=attn.config.num_query_heads,
            num_head_repeats=attn.config.num_head_repeats,
            with_rope=attn.config.with_rope,
            develop_dtype=config.develop_dtype,
        )
    # endregion
    # region qkv projection
    # Smoothing is worthwhile only if the projection's weights or inputs are quantized.
    needs_quant = config.enabled_ipts and config.ipts.is_enabled_for(attn.qkv_proj_key)
    needs_quant = needs_quant or (config.enabled_wgts and config.wgts.is_enabled_for(attn.qkv_proj_key))
    if config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(attn.qkv_proj_key) and needs_quant:
        logger.debug("- %s.%s", attn.name, attn.qkv_proj_rkey)
        cache_key = attn.v_proj_name
        # Migrate input magnitude into the preceding norm (or q/k/v weights).
        smooth_cache[cache_key] = smooth_linear_modules(
            attn.parent.pre_attn_norm,
            attn.qkv_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config.wgts, key=attn.qkv_proj_key),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=attn.qkv_proj_key),
            inputs=layer_cache[attn.q_proj_name].inputs if layer_cache else None,
            eval_inputs=layer_cache[attn.name].inputs if layer_cache else None,
            eval_module=attn,
            eval_kwargs=attn.filter_kwargs(layer_kwargs),
            develop_dtype=config.develop_dtype,
        )
        # Without a pre-attention norm to fold into, smooth the activations at runtime.
        if not attn.parent.pre_attn_norm:
            ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(attn.qkv_proj)
    # endregion
    # region output projection
    needs_quant = config.enabled_ipts and config.ipts.is_enabled_for(attn.out_proj_key)
    needs_quant = needs_quant or (config.enabled_wgts and config.wgts.is_enabled_for(attn.out_proj_key))
    if config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(attn.out_proj_key) and needs_quant:
        logger.debug("- %s.%s", attn.name, attn.out_proj_rkey)
        cache_key = attn.o_proj_name
        # For linear attention there is no v_proj to fold the scale into.
        smooth_cache[cache_key] = smooth_linear_modules(
            None if attn.config.linear_attn else attn.v_proj,
            attn.o_proj,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config.wgts, key=attn.out_proj_key),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=attn.out_proj_key),
            inputs=layer_cache[cache_key].inputs if layer_cache else None,
            eval_inputs=layer_cache[cache_key].inputs if layer_cache else None,
            eval_module=attn.o_proj,
            num_heads=attn.config.num_query_heads,
            num_head_repeats=attn.config.num_head_repeats,
            develop_dtype=config.develop_dtype,
        )
        if attn.config.linear_attn:
            ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(attn.o_proj)
    # endregion
    num_experts = ffn.config.num_experts
    # region up projection
    needs_quant = config.enabled_ipts and config.ipts.is_enabled_for(ffn.up_proj_key)
    needs_quant = needs_quant or (config.enabled_wgts and config.wgts.is_enabled_for(ffn.up_proj_key))
    if config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(ffn.up_proj_key) and needs_quant:
        logger.debug("- %s.%s", ffn.name, ffn.up_proj_rkey)
        cache_key = ffn.name
        smooth_cache[cache_key] = smooth_linear_modules(
            ffn.parent.pre_ffn_norm,
            ffn.up_projs,
            scale=smooth_cache.get(cache_key, None),
            config=config.smooth.proj,
            weight_quantizer=Quantizer(config.wgts, key=ffn.up_proj_key),
            input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=ffn.up_proj_key),
            inputs=layer_cache[ffn.name].inputs if layer_cache else None,
            eval_inputs=layer_cache[ffn.name].inputs if layer_cache else None,
            eval_module=ffn,
            # The MoE router shares the up projections' input and must be scaled too.
            extra_modules=[ffn.moe_gate] if num_experts > 1 else None,
            develop_dtype=config.develop_dtype,
        )
        if not ffn.parent.pre_ffn_norm:
            hook = ActivationSmoother(smooth_cache[cache_key], channels_dim=-1).as_hook().register(ffn.up_projs)
            if num_experts > 1:
                hook.register(ffn.moe_gate)
    # endregion
    # region down projection
    needs_quant = config.enabled_ipts and config.ipts.is_enabled_for(ffn.down_proj_key)
    needs_quant = needs_quant or (config.enabled_wgts and config.wgts.is_enabled_for(ffn.down_proj_key))
    if config.smooth.enabled_proj and config.smooth.proj.is_enabled_for(ffn.down_proj_key) and needs_quant:
        # Each expert's down projection is smoothed against its own up projection.
        for expert_idx in range(num_experts):
            logger.debug("- %s.%s", ffn.expert_names[expert_idx], ffn.down_proj_rkey)
            cache_key = ffn.down_proj_names[expert_idx]
            smooth_cache[cache_key] = smooth_linear_modules(
                ffn.up_projs[expert_idx],
                ffn.down_projs[expert_idx],
                scale=smooth_cache.get(cache_key, None),
                config=config.smooth.proj,
                weight_quantizer=Quantizer(config.wgts, key=ffn.down_proj_key),
                input_quantizer=Quantizer(config.ipts, channels_dim=-1, key=ffn.down_proj_key),
                inputs=layer_cache[ffn.down_proj_names[expert_idx]].inputs if layer_cache else None,
                eval_inputs=layer_cache[ffn.down_proj_names[expert_idx]].inputs if layer_cache else None,
                eval_module=ffn.down_projs[expert_idx],
                develop_dtype=config.develop_dtype,
            )
    # endregion
    tools.logging.Formatter.indent_dec()
@torch.inference_mode()
def smooth_llm(
    model: nn.Module | LlmModelStruct,
    /,
    config: LlmQuantConfig,
    tokenizer: PreTrainedTokenizer | None = None,
    smooth_cache: dict[str, torch.Tensor] | None = None,
) -> dict[str, torch.Tensor]:
    """Smooth the large language model.

    Args:
        model (`nn.Module` or `LlmStruct`):
            Model to be smoothed.
        config (`LlmQuantConfig`):
            Quantization configuration.
        tokenizer (`PreTrainedTokenizer`, *optional*, defaults to `None`):
            Tokenizer used to build the calibration loader when no cache is given.
        smooth_cache (`dict[str, torch.Tensor]`, *optional*, defaults to `None`):
            Smoothing scale caches.

    Returns:
        `dict[str, torch.Tensor]`:
            Dictionary mapping module names to smoothing scales.
    """
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    cache = smooth_cache or {}
    if cache:
        # Scales are already known: re-apply them layer by layer without any
        # calibration data.
        for block in model.backbone_struct.layer_structs:
            smooth_llm_layer(layer=block, config=config, smooth_cache=cache)
        return cache
    # No cached scales: calibrate each layer from collected activations.
    with tools.logging.redirect_tqdm():
        activations = config.calib.build_loader(tokenizer).iter_layer_activations(
            model,
            needs_inputs_fn=get_needs_inputs_fn(model=model, config=config),
            needs_outputs_fn=get_needs_outputs_fn(model=model, config=config),
        )
        progress = tqdm(
            activations,
            desc="smoothing",
            leave=False,
            total=len(model.backbone_struct.layer_structs),
            dynamic_ncols=True,
        )
        for _, (block, block_cache, block_kwargs) in progress:
            smooth_llm_layer(
                layer=block,
                config=config,
                smooth_cache=cache,
                layer_cache=block_cache,
                layer_kwargs=block_kwargs,
            )
    return cache
================================================
FILE: deepcompressor/app/llm/quant/utils.py
================================================
# -*- coding: utf-8 -*-
"""LLM quantization utils module."""
import typing as tp
import torch.nn as nn
from ..nn.struct import LlmModelStruct
from .quantizer.config import LlmModuleQuantizerConfig
__all__ = ["get_needs_inputs_fn", "get_needs_outputs_fn"]
def get_needs_inputs_fn(model: LlmModelStruct, config: LlmModuleQuantizerConfig) -> tp.Callable[[str, nn.Module], bool]:
    """Get function that checks if the module needs to cache the inputs.

    Args:
        model (`LlmStruct`):
            Model struct.
        config (`LlmModuleQuantizerConfig`):
            Module quantization config.

    Returns:
        `Callable[[str, nn.Module], bool]`:
            Function to check if the module needs to cache the inputs.
    """
    # All transformer blocks share one structure, so the first layer is representative.
    first_layer = model.backbone_struct.layer_structs[0]
    attn, ffn = first_layer.attn_struct, first_layer.ffn_struct

    def quantized(key: str) -> bool:
        # A module's inputs are needed whenever its weights or inputs are quantized.
        return (config.enabled_wgts and config.wgts.is_enabled_for(key)) or (
            config.enabled_ipts and config.ipts.is_enabled_for(key)
        )

    collected: set[str] = set()
    if quantized(attn.qkv_proj_key):
        collected.update((attn.rname, attn.v_proj_rname))
    if quantized(attn.out_proj_key):
        collected.add(attn.o_proj_rname)
    if quantized(ffn.up_proj_key):
        collected.update((ffn.rname, ffn.up_proj_rnames[0]))
    if quantized(ffn.down_proj_key):
        collected.add(ffn.down_proj_rnames[0])
    if config.enabled_opts:
        # Output (q/k/v state) quantization needs the attention module's inputs.
        collected.add(attn.rname)
    suffixes = tuple(collected)

    def needs_inputs(name: str, module: nn.Module) -> bool:
        # str.endswith accepts a tuple of suffixes; an empty tuple matches nothing.
        return name.endswith(suffixes)

    return needs_inputs
def get_needs_outputs_fn(
    model: LlmModelStruct, config: LlmModuleQuantizerConfig
) -> tp.Callable[[str, nn.Module], bool]:
    """Get function that checks if the module needs to cache the outputs.

    Args:
        model (`LlmStruct`):
            Model struct.
        config (`LlmModuleQuantizerConfig`):
            Module quantization config.

    Returns:
        `Callable[[str, nn.Module], bool]`:
            Function to check if the module needs to cache the outputs.
    """
    # All transformer blocks share one structure, so the first layer is representative.
    attn = model.backbone_struct.layer_structs[0].attn_struct
    if config.enabled_opts:
        # Only output (q/k/v state) quantization requires cached outputs.
        suffixes = (attn.q_rname, attn.k_rname, attn.v_rname)
    else:
        suffixes = ()

    def needs_outputs(name: str, module: nn.Module) -> bool:
        # str.endswith accepts a tuple of suffixes; an empty tuple matches nothing.
        return name.endswith(suffixes)

    return needs_outputs
================================================
FILE: deepcompressor/app/llm/quant/weight.py
================================================
# -*- coding: utf-8 -*-
"""LLM weight quantization calibration module."""
import gc
import typing as tp
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from deepcompressor.data.cache import IOTensorsCache
from deepcompressor.data.zero import ZeroPointDomain
from deepcompressor.utils import tools
from ..nn.struct import LlmModelStruct, LlmSelfAttentionStruct, LlmTransformerBlockStruct
from .config import LlmQuantConfig
from .quantizer import LlmWeightQuantizer
from .utils import get_needs_inputs_fn
__all__ = ["quantize_llm_weights"]
@torch.inference_mode()
def quantize_llm_layer_weights(  # noqa: C901
    layer: LlmTransformerBlockStruct,
    config: LlmQuantConfig,
    quantizer_state_dict: dict[str, tp.Any],
    layer_cache: dict[str, IOTensorsCache] | None = None,
    layer_kwargs: dict[str, tp.Any] | None = None,
    return_with_scale_state_dict: bool = False,
) -> dict[str, torch.Tensor | float | None]:
    """Calibrate the weight quantization ranges of modules in a layer.

    Runs two passes over the layer's key linear modules: first calibrate the
    dynamic range of every enabled quantizer (modules already present in
    ``quantizer_state_dict`` are skipped), then quantize each weight in place
    using the calibrated state.

    Args:
        layer (`LlmTransformerBlockStruct`):
            Layer.
        config (`LlmQuantConfig`):
            Quantization config.
        quantizer_state_dict (`dict[str, Any]`):
            Weight quantizer state dict; updated in place during calibration.
        layer_cache (`dict[str, IOTensorsCache]` or `None`, *optional*, defaults to `None`):
            Layer activations cache.
        layer_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
            Keyword arguments for the layer.
        return_with_scale_state_dict (bool, *optional*, defaults to `False`):
            Whether to return with scale state dict.

    Returns:
        `dict[str, torch.Tensor | float | None]`:
            Scale state dict (empty unless ``return_with_scale_state_dict``).
    """
    logger = tools.logging.getLogger(f"{__name__}.WeightQuant")
    logger.debug("- Quantizing layer %s", layer.name)
    tools.logging.Formatter.indent_inc()
    layer_cache = layer_cache or {}
    layer_kwargs = layer_kwargs or {}
    # Pass 1: calibrate dynamic ranges for every enabled weight quantizer.
    for module_key, module_name, module, parent, field_name in layer.named_key_modules():
        assert isinstance(module, nn.Linear)
        if field_name in ("q_proj", "k_proj"):
            # q/k projections are evaluated through the enclosing attention
            # module so the calibration error reflects the attention output.
            assert isinstance(parent, LlmSelfAttentionStruct)
            eval_name, eval_module, eval_kwargs = parent.name, parent.module, parent.filter_kwargs(layer_kwargs)
        else:
            eval_name, eval_module, eval_kwargs = module_name, module, None
        quantizer = LlmWeightQuantizer(config.wgts, develop_dtype=config.develop_dtype, key=module_key)
        if quantizer.is_enabled():
            if module_name not in quantizer_state_dict:
                logger.debug("- Calibrating %s.weight", module_name)
                quantizer.calibrate_dynamic_range(
                    module=module,
                    inputs=layer_cache[module_name].inputs if layer_cache else None,
                    eval_inputs=layer_cache[eval_name].inputs if layer_cache else None,
                    eval_module=eval_module,
                    eval_kwargs=eval_kwargs,
                )
                quantizer_state_dict[module_name] = quantizer.state_dict()
                # Calibration can hold large intermediates; release eagerly.
                gc.collect()
                torch.cuda.empty_cache()
    # Pass 2: quantize weights in place using the calibrated quantizer states.
    scale_state_dict: dict[str, torch.Tensor | float | None] = {}
    for module_key, module_name, module, _, _ in layer.named_key_modules():
        assert isinstance(module, nn.Linear)
        quantizer = LlmWeightQuantizer(config.wgts, develop_dtype=config.develop_dtype, key=module_key)
        param_name = f"{module_name}.weight"
        if quantizer.is_enabled():
            logger.debug("- Quantizing %s", param_name)
            quantizer.load_state_dict(quantizer_state_dict[module_name], device=module.weight.device)
            result = quantizer.quantize(
                module.weight.data,
                inputs=layer_cache[module_name].inputs.front() if layer_cache else None,
                return_with_dequant=True,
                return_with_quant=return_with_scale_state_dict,
            )
            # Replace the weight with its dequantized counterpart in place.
            module.weight.data = result.data
            if return_with_scale_state_dict:
                scale_state_dict.update(result.scale.state_dict(f"{param_name}.scale"))
                # Zero-point key name depends on whether zeros live in the
                # scaled (post-scale) domain.
                zero_name = "scaled_zero" if config.wgts.zero_domain is ZeroPointDomain.PostScale else "zero"
                if isinstance(result.zero, torch.Tensor):
                    scale_state_dict[f"{param_name}.{zero_name}"] = result.zero.to("cpu")
                else:
                    scale_state_dict[f"{param_name}.{zero_name}"] = result.zero
            del result
            gc.collect()
            torch.cuda.empty_cache()
    tools.logging.Formatter.indent_dec()
    return scale_state_dict
@torch.inference_mode()
def quantize_llm_weights(
    model: nn.Module | LlmModelStruct,
    config: LlmQuantConfig,
    tokenizer: PreTrainedTokenizer | None = None,
    quantizer_state_dict: dict[str, tp.Any] | None = None,
    return_with_scale_state_dict: bool = False,
) -> tuple[dict[str, tp.Any], dict[str, torch.Tensor | float | None]]:
    """Quantize the large language model weights.

    Iterates over transformer layers and delegates to
    ``quantize_llm_layer_weights``. Calibration activations are only collected
    when GPTQ is enabled, or when no pre-computed quantizer state exists and
    the weight config requires calibration data.

    Args:
        model (`nn.Module` or `LlmStruct`):
            Model to be quantized.
        config (`LlmQuantConfig`):
            Quantization configuration.
        tokenizer (`PreTrainedTokenizer`, *optional*, defaults to `None`):
            Tokenizer used to build the calibration data loader.
        quantizer_state_dict (`dict[str, Any]`, *optional*, defaults to `None`):
            Weight quantizer state dict.
        return_with_scale_state_dict (bool, *optional*, defaults to `False`):
            Whether to return with scale state dict.

    Returns:
        `tuple[dict[str, Any], dict[str, torch.Tensor | float | None]]`:
            Weight quantizer cache and scale state dict.
    """
    if not isinstance(model, LlmModelStruct):
        model = LlmModelStruct.construct(model)
    assert isinstance(model, LlmModelStruct)
    quantizer_state_dict = quantizer_state_dict or {}
    scale_state_dict: dict[str, torch.Tensor | float | None] = {}
    with tools.logging.redirect_tqdm():
        # Data-dependent path: stream layer activations alongside each layer.
        if config.wgts.enabled_gptq or (not quantizer_state_dict and config.wgts.needs_calib_data):
            for _, (layer, layer_cache, layer_kwargs) in tqdm(
                config.calib.build_loader(tokenizer).iter_layer_activations(
                    model,
                    needs_inputs_fn=get_needs_inputs_fn(model=model, config=config),
                ),
                desc="quantizing weights",
                leave=False,
                total=len(model.backbone_struct.layer_structs),
                dynamic_ncols=True,
            ):
                scale_state_dict.update(
                    quantize_llm_layer_weights(
                        layer=layer,
                        config=config,
                        quantizer_state_dict=quantizer_state_dict,
                        layer_cache=layer_cache,
                        layer_kwargs=layer_kwargs,
                        return_with_scale_state_dict=return_with_scale_state_dict,
                    )
                )
        else:
            # Data-free path: quantize layer by layer without activation caches.
            for layer in tqdm(
                model.backbone_struct.layer_structs, desc="quantizing weights", leave=False, dynamic_ncols=True
            ):
                scale_state_dict.update(
                    quantize_llm_layer_weights(
                        layer=layer,
                        config=config,
                        quantizer_state_dict=quantizer_state_dict,
                        return_with_scale_state_dict=return_with_scale_state_dict,
                    )
                )
    return quantizer_state_dict, scale_state_dict
================================================
FILE: deepcompressor/backend/__init__.py
================================================
================================================
FILE: deepcompressor/backend/nunchaku/__init__.py
================================================
================================================
FILE: deepcompressor/backend/nunchaku/convert.py
================================================
"""Converts a DeepCompressor state dict to a Nunchaku state dict."""
import argparse
import os
import safetensors.torch
import torch
import tqdm
from .utils import convert_to_nunchaku_w4x4y16_linear_weight, convert_to_nunchaku_w4x16_linear_weight
def convert_to_nunchaku_w4x4y16_linear_state_dict(
    weight: torch.Tensor,
    scale: torch.Tensor,
    bias: torch.Tensor | None = None,
    smooth: torch.Tensor | None = None,
    lora: tuple[torch.Tensor, torch.Tensor] | None = None,
    shift: torch.Tensor | None = None,
    smooth_fused: bool = False,
    float_point: bool = False,
    subscale: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
    """Convert a quantized linear layer into a Nunchaku W4A4 (16-bit output) state dict.

    Args:
        weight: linear (or pointwise-conv) weight tensor.
        scale: first-level quantization scales; per-tensor when it has a single element.
        bias: optional bias.
        smooth: optional input smoothing factors.
        lora: optional low-rank branch as ``(lora_down, lora_up)``.
        shift: optional input shift, folded into the bias via the LoRA branch.
        smooth_fused: whether smoothing is already fused into the previous layer.
        float_point: whether float-point 4-bit quantization is used.
        subscale: optional second-level quantization scales.

    Returns:
        State dict with packed ``qweight``, scales, bias, smoothing factors and LoRA.
    """
    if weight.ndim > 2:  # pointwise conv
        # A 1x1 conv weight of shape (out, in, 1, 1) is equivalent to a linear weight.
        assert weight.numel() == weight.shape[0] * weight.shape[1]
        weight = weight.view(weight.shape[0], weight.shape[1])
    if scale.numel() > 1:
        assert scale.ndim == weight.ndim * 2
        assert scale.numel() == scale.shape[0] * scale.shape[2]
        scale = scale.view(scale.shape[0], 1, scale.shape[2], 1)
        # shape[2] == 1 means one scale per output channel ("wcscales");
        # otherwise group-wise scales ("wscales") -- key names follow Nunchaku.
        scale_key = "wcscales" if scale.shape[2] == 1 else "wscales"
    else:
        # Single per-tensor scale.
        scale_key = "wtscale"
    if subscale is None:
        subscale_key = ""
    else:
        assert subscale.ndim == weight.ndim * 2
        assert subscale.numel() == subscale.shape[0] * subscale.shape[2]
        assert subscale.numel() > 1
        subscale = subscale.view(subscale.shape[0], 1, subscale.shape[2], 1)
        subscale_key = "wcscales" if subscale.shape[2] == 1 else "wscales"
    if lora is not None and (smooth is not None or shift is not None):
        # unsmooth lora down projection
        dtype = weight.dtype
        lora_down, lora_up = lora
        # Work in float64 to avoid precision loss while folding smooth/shift.
        lora_down = lora_down.to(dtype=torch.float64)
        if smooth is not None and not smooth_fused:
            lora_down = lora_down.div_(smooth.to(torch.float64).unsqueeze(0))
        if shift is not None:
            # Fold the constant input shift into the bias: bias += up @ down @ shift.
            bias = torch.zeros([lora_up.shape[0]], dtype=torch.float64) if bias is None else bias.to(torch.float64)
            if shift.numel() == 1:
                shift = shift.view(1, 1).expand(lora_down.shape[1], 1).to(torch.float64)
            else:
                shift = shift.view(-1, 1).to(torch.float64)
            bias = bias.add_((lora_up.to(dtype=torch.float64) @ lora_down @ shift).view(-1))
            bias = bias.to(dtype=dtype)
        lora = (lora_down.to(dtype=dtype), lora_up)
    weight, scale, bias, smooth, lora, subscale = convert_to_nunchaku_w4x4y16_linear_weight(
        weight, scale=scale, bias=bias, smooth=smooth, lora=lora, float_point=float_point, subscale=subscale
    )
    state_dict: dict[str, torch.Tensor] = {}
    state_dict["qweight"] = weight
    state_dict[scale_key] = scale
    if subscale is not None:
        state_dict[subscale_key] = subscale
    state_dict["bias"] = bias
    # NOTE(review): assumes the helper above always returns a smooth tensor
    # (even when `smooth` was None) -- confirm against convert_to_nunchaku_w4x4y16_linear_weight.
    state_dict["smooth_orig"] = smooth
    # When smoothing is fused into the previous layer, the runtime factor is identity.
    state_dict["smooth"] = torch.ones_like(smooth) if smooth_fused else smooth.clone()
    if lora is not None:
        state_dict["lora_down"] = lora[0]
        state_dict["lora_up"] = lora[1]
    return state_dict
def convert_to_nunchaku_w4x16_adanorm_single_state_dict(
    weight: torch.Tensor,
    scale: torch.Tensor,
    bias: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Convert an AdaNorm-single modulation linear to a Nunchaku W4A16 state dict.

    Args:
        weight: modulation linear weight (3 fused splits).
        scale: quantization scales of the weight.
        bias: modulation linear bias.

    Returns:
        State dict with packed ``qweight``, ``wscales``, ``wzeros`` and ``bias``.
    """
    weight, scale, zero, bias = convert_to_nunchaku_w4x16_linear_weight(
        weight, scale=scale, bias=bias, adanorm_splits=3
    )
    # Fix: dropped a redundant second `state_dict = {}` that immediately
    # discarded the freshly created (and annotated) dict.
    state_dict: dict[str, torch.Tensor] = {}
    state_dict["qweight"] = weight
    state_dict["wscales"] = scale
    state_dict["wzeros"] = zero
    state_dict["bias"] = bias
    return state_dict
def convert_to_nunchaku_w4x16_adanorm_zero_state_dict(
    weight: torch.Tensor,
    scale: torch.Tensor,
    bias: torch.Tensor,
) -> dict[str, torch.Tensor]:
    """Convert an AdaNorm-zero modulation linear to a Nunchaku W4A16 state dict.

    Args:
        weight: modulation linear weight (6 fused splits).
        scale: quantization scales of the weight.
        bias: modulation linear bias.

    Returns:
        State dict with packed ``qweight``, ``wscales``, ``wzeros`` and ``bias``.
    """
    weight, scale, zero, bias = convert_to_nunchaku_w4x16_linear_weight(
        weight, scale=scale, bias=bias, adanorm_splits=6
    )
    # Fix: dropped a redundant second `state_dict = {}` that immediately
    # discarded the freshly created (and annotated) dict.
    state_dict: dict[str, torch.Tensor] = {}
    state_dict["qweight"] = weight
    state_dict["wscales"] = scale
    state_dict["wzeros"] = zero
    state_dict["bias"] = bias
    return state_dict
def update_state_dict(
    lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Merge ``rhs`` into ``lhs`` in place, optionally namespacing keys with ``prefix``.

    Raises an ``AssertionError`` if a destination key already exists, so
    conversions never silently overwrite each other.
    """
    for src_key in rhs:
        dst_key = f"{prefix}.{src_key}" if prefix else src_key
        assert dst_key not in lhs, f"Key {dst_key} already exists in the state dict."
        lhs[dst_key] = rhs[src_key]
    return lhs
def convert_to_nunchaku_transformer_block_state_dict(
    state_dict: dict[str, torch.Tensor],
    scale_dict: dict[str, torch.Tensor],
    smooth_dict: dict[str, torch.Tensor],
    branch_dict: dict[str, torch.Tensor],
    block_name: str,
    local_name_map: dict[str, str | list[str]],
    smooth_name_map: dict[str, str],
    branch_name_map: dict[str, str],
    convert_map: dict[str, str],
    float_point: bool = False,
) -> dict[str, torch.Tensor]:
    """Convert one transformer block's tensors from DeepCompressor to Nunchaku layout.

    Args:
        state_dict: full model state dict (weights and biases).
        scale_dict: quantization scales keyed by ``<param>.weight.scale.<level>``.
        smooth_dict: smoothing factors keyed by module name.
        branch_dict: low-rank branch (LoRA) weights keyed by module name.
        block_name: parameter-name prefix of this block inside ``state_dict``.
        local_name_map: converted local name -> candidate local name(s); several
            candidates (e.g. q/k/v) are fused by concatenation on the output dim.
        smooth_name_map: converted local name -> local name whose smoothing
            factor applies to the fused module.
        branch_name_map: converted local name -> local name whose low-rank
            branch applies to the fused module.
        convert_map: converted local name -> conversion kind
            ("adanorm_single", "adanorm_zero" or "linear").
        float_point: whether float-point 4-bit quantization is used.

    Returns:
        Converted state dict with keys local to the block (no block prefix).
    """
    print(f"Converting block {block_name}...")
    converted: dict[str, torch.Tensor] = {}
    # Restrict lookups to this block's parameters.
    candidates: dict[str, torch.Tensor] = {
        param_name: param for param_name, param in state_dict.items() if param_name.startswith(block_name)
    }
    for converted_local_name, candidate_local_names in tqdm.tqdm(
        local_name_map.items(), desc=f"Converting {block_name}", dynamic_ncols=True
    ):
        if isinstance(candidate_local_names, str):
            candidate_local_names = [candidate_local_names]
        candidate_names = [f"{block_name}.{candidate_local_name}" for candidate_local_name in candidate_local_names]
        weight = [candidates[f"{candidate_name}.weight"] for candidate_name in candidate_names]
        bias = [candidates.get(f"{candidate_name}.bias", None) for candidate_name in candidate_names]
        # Level 0 is the first-level scale; level 1 the optional second-level
        # subscale -- presumably matching the quantizer's scale export; verify.
        scale = [scale_dict.get(f"{candidate_name}.weight.scale.0", None) for candidate_name in candidate_names]
        subscale = [scale_dict.get(f"{candidate_name}.weight.scale.1", None) for candidate_name in candidate_names]
        if len(weight) > 1:
            # Fuse multiple candidate modules (e.g. q/k/v) along dim 0.
            bias = None if all(b is None for b in bias) else torch.concat(bias, dim=0)
            if all(s is None for s in scale):
                scale = None
            else:
                if scale[0].numel() == 1:  # switch from per-tensor to per-channel scale
                    assert all(s.numel() == 1 for s in scale)
                    scale = torch.concat(
                        [
                            s.view(-1).expand(weight[i].shape[0]).reshape(weight[i].shape[0], 1, 1, 1)
                            for i, s in enumerate(scale)
                        ],
                        dim=0,
                    )
                else:
                    scale = torch.concat(scale, dim=0)
            subscale = None if all(s is None for s in subscale) else torch.concat(subscale, dim=0)
            weight = torch.concat(weight, dim=0)
        else:
            weight, bias, scale, subscale = weight[0], bias[0], scale[0], subscale[0]
        smooth = smooth_dict.get(f"{block_name}.{smooth_name_map.get(converted_local_name, '')}", None)
        branch = branch_dict.get(f"{block_name}.{branch_name_map.get(converted_local_name, '')}", None)
        if branch is not None:
            branch = (branch["a.weight"], branch["b.weight"])
        if scale is None:
            # No scale recorded -> module was not quantized: copy verbatim.
            assert smooth is None and branch is None and subscale is None
            print(f" - Copying {block_name} weights of {candidate_local_names} as {converted_local_name}.weight")
            converted[f"{converted_local_name}.weight"] = weight.clone().cpu()
            if bias is not None:
                print(f" - Copying {block_name} biases of {candidate_local_names} as {converted_local_name}.bias")
                converted[f"{converted_local_name}.bias"] = bias.clone().cpu()
            continue
        if convert_map[converted_local_name] == "adanorm_single":
            print(f" - Converting {block_name} weights of {candidate_local_names} to {converted_local_name}.")
            update_state_dict(
                converted,
                convert_to_nunchaku_w4x16_adanorm_single_state_dict(weight=weight, scale=scale, bias=bias),
                prefix=converted_local_name,
            )
        elif convert_map[converted_local_name] == "adanorm_zero":
            print(f" - Converting {block_name} weights of {candidate_local_names} to {converted_local_name}.")
            update_state_dict(
                converted,
                convert_to_nunchaku_w4x16_adanorm_zero_state_dict(weight=weight, scale=scale, bias=bias),
                prefix=converted_local_name,
            )
        elif convert_map[converted_local_name] == "linear":
            # Out-projection smoothing may already be fused into the preceding
            # projection, controlled by the `proj.fuse_when_possible` flag.
            smooth_fused = "out_proj" in converted_local_name and smooth_dict.get("proj.fuse_when_possible", True)
            # `[:-7]` strips a trailing ".weight"-length suffix from the module
            # path -- presumably drops ".linear" so the shift is looked up on
            # the wrapping module; TODO confirm against how shifts are exported.
            shift = [candidates.get(f"{candidate_name[:-7]}.shift", None) for candidate_name in candidate_names]
            assert all(s == shift[0] for s in shift)
            shift = shift[0]
            print(
                f" - Converting {block_name} weights of {candidate_local_names} to {converted_local_name}."
                f" (smooth_fused={smooth_fused}, shifted={shift is not None}, float_point={float_point})"
            )
            update_state_dict(
                converted,
                convert_to_nunchaku_w4x4y16_linear_state_dict(
                    weight=weight,
                    scale=scale,
                    bias=bias,
                    smooth=smooth,
                    lora=branch,
                    shift=shift,
                    smooth_fused=smooth_fused,
                    float_point=float_point,
                    subscale=subscale,
                ),
                prefix=converted_local_name,
            )
        else:
            raise NotImplementedError(f"Conversion of {convert_map[converted_local_name]} is not implemented.")
    return converted
def convert_to_nunchaku_flux_single_transformer_block_state_dict(
    state_dict: dict[str, torch.Tensor],
    scale_dict: dict[str, torch.Tensor],
    smooth_dict: dict[str, torch.Tensor],
    branch_dict: dict[str, torch.Tensor],
    block_name: str,
    float_point: bool = False,
) -> dict[str, torch.Tensor]:
    """Convert a FLUX *single-stream* transformer block to the Nunchaku layout.

    Supplies the Diffusers->Nunchaku module-name maps and delegates the tensor
    work to `convert_to_nunchaku_transformer_block_state_dict`.
    """
    # The MLP down projection may or may not be wrapped in a `.linear`
    # submodule depending on how the model was patched; probe both layouts.
    down_proj_local_name = "proj_out.linears.1.linear"
    if f"{block_name}.{down_proj_local_name}.weight" not in state_dict:
        down_proj_local_name = "proj_out.linears.1"
    assert f"{block_name}.{down_proj_local_name}.weight" in state_dict
    return convert_to_nunchaku_transformer_block_state_dict(
        state_dict=state_dict,
        scale_dict=scale_dict,
        smooth_dict=smooth_dict,
        branch_dict=branch_dict,
        block_name=block_name,
        # Nunchaku local name -> Diffusers local name(s); q/k/v fuse into one qkv.
        local_name_map={
            "norm.linear": "norm.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": down_proj_local_name,
        },
        # qkv and mlp_fc1 share one input, hence one smoothing factor.
        smooth_name_map={
            "qkv_proj": "attn.to_q",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "attn.to_q",
            "mlp_fc2": down_proj_local_name,
        },
        branch_name_map={
            "qkv_proj": "attn.to_q",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": down_proj_local_name,
        },
        convert_map={
            "norm.linear": "adanorm_single",
            "qkv_proj": "linear",
            "out_proj": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
        },
        float_point=float_point,
    )
def convert_to_nunchaku_flux_transformer_block_state_dict(
    state_dict: dict[str, torch.Tensor],
    scale_dict: dict[str, torch.Tensor],
    smooth_dict: dict[str, torch.Tensor],
    branch_dict: dict[str, torch.Tensor],
    block_name: str,
    float_point: bool = False,
) -> dict[str, torch.Tensor]:
    """Convert a FLUX *dual-stream* transformer block to the Nunchaku layout.

    Supplies the Diffusers->Nunchaku module-name maps (image stream and context
    stream) and delegates to `convert_to_nunchaku_transformer_block_state_dict`.
    """
    # The FFN down projections may or may not be wrapped in a `.linear`
    # submodule depending on how the model was patched; probe both layouts.
    down_proj_local_name = "ff.net.2.linear"
    if f"{block_name}.{down_proj_local_name}.weight" not in state_dict:
        down_proj_local_name = "ff.net.2"
    assert f"{block_name}.{down_proj_local_name}.weight" in state_dict
    context_down_proj_local_name = "ff_context.net.2.linear"
    if f"{block_name}.{context_down_proj_local_name}.weight" not in state_dict:
        context_down_proj_local_name = "ff_context.net.2"
    assert f"{block_name}.{context_down_proj_local_name}.weight" in state_dict
    return convert_to_nunchaku_transformer_block_state_dict(
        state_dict=state_dict,
        scale_dict=scale_dict,
        smooth_dict=smooth_dict,
        branch_dict=branch_dict,
        block_name=block_name,
        # Nunchaku local name -> Diffusers local name(s); q/k/v fuse into one qkv.
        local_name_map={
            "norm1.linear": "norm1.linear",
            "norm1_context.linear": "norm1_context.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "norm_added_q": "attn.norm_added_q",
            "norm_added_k": "attn.norm_added_k",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_add_out",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": down_proj_local_name,
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": context_down_proj_local_name,
        },
        # NOTE(review): out_proj_context reuses attn.to_out.0's smoothing factor
        # (both consume the same attention output) -- confirm intent.
        smooth_name_map={
            "qkv_proj": "attn.to_q",
            "qkv_proj_context": "attn.add_k_proj",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_out.0",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": down_proj_local_name,
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": context_down_proj_local_name,
        },
        branch_name_map={
            "qkv_proj": "attn.to_q",
            "qkv_proj_context": "attn.add_k_proj",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_add_out",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": down_proj_local_name,
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": context_down_proj_local_name,
        },
        convert_map={
            "norm1.linear": "adanorm_zero",
            "norm1_context.linear": "adanorm_zero",
            "qkv_proj": "linear",
            "qkv_proj_context": "linear",
            "out_proj": "linear",
            "out_proj_context": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
            "mlp_context_fc1": "linear",
            "mlp_context_fc2": "linear",
        },
        float_point=float_point,
    )
def convert_to_nunchaku_flux_state_dicts(
    state_dict: dict[str, torch.Tensor],
    scale_dict: dict[str, torch.Tensor],
    smooth_dict: dict[str, torch.Tensor],
    branch_dict: dict[str, torch.Tensor],
    float_point: bool = False,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Split a FLUX checkpoint into converted transformer blocks and the rest.

    Returns a pair ``(converted, other)``: the Nunchaku-format tensors for all
    (single-)transformer blocks, and the untouched non-block parameters.
    """
    block_prefixes = ("transformer_blocks.", "single_transformer_blocks.")
    other: dict[str, torch.Tensor] = {}
    block_names: set[str] = set()
    for param_name, param in state_dict.items():
        if param_name.startswith(block_prefixes):
            # Keep only the "<kind>.<index>" prefix identifying the block.
            block_names.add(".".join(param_name.split(".")[:2]))
        else:
            other[param_name] = param
    # Sort by block kind, then numerically by block index.
    block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
    print(f"Converting {len(block_names)} transformer blocks...")
    converted: dict[str, torch.Tensor] = {}
    for block_name in block_names:
        if block_name.startswith("transformer_blocks"):
            convert_fn = convert_to_nunchaku_flux_transformer_block_state_dict
        else:
            convert_fn = convert_to_nunchaku_flux_single_transformer_block_state_dict
        update_state_dict(
            converted,
            convert_fn(
                state_dict=state_dict,
                scale_dict=scale_dict,
                smooth_dict=smooth_dict,
                branch_dict=branch_dict,
                block_name=block_name,
                float_point=float_point,
            ),
            prefix=block_name,
        )
    return converted, other
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantization checkpoint directory.")
    parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.")
    parser.add_argument("--model-name", type=str, default=None, help="name of the model.")
    parser.add_argument("--float-point", action="store_true", help="use float-point 4-bit quantization.")
    args = parser.parse_args()
    if not args.output_root:
        args.output_root = args.quant_path
    if args.model_name is None:
        # Fix: the parser defines no `--model-path`, so `args.model_path` raised
        # AttributeError here; derive the model name from the (required)
        # quantization checkpoint path instead.
        assert args.quant_path is not None, "model name or path is required."
        model_name = args.quant_path.rstrip(os.sep).split(os.sep)[-1]
        print(f"Model name not provided, using {model_name} as the model name.")
    else:
        model_name = args.model_name
    assert model_name, "Model name must be provided."
    assert "flux" in model_name.lower(), "Only Flux models are supported."
    # Checkpoint components produced by the DeepCompressor diffusion PTQ pipeline.
    state_dict_path = os.path.join(args.quant_path, "model.pt")
    scale_dict_path = os.path.join(args.quant_path, "scale.pt")
    smooth_dict_path = os.path.join(args.quant_path, "smooth.pt")
    branch_dict_path = os.path.join(args.quant_path, "branch.pt")
    map_location = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
    state_dict = torch.load(state_dict_path, map_location=map_location)
    scale_dict = torch.load(scale_dict_path, map_location="cpu")
    # Smoothing factors and low-rank branches are optional artifacts.
    smooth_dict = torch.load(smooth_dict_path, map_location=map_location) if os.path.exists(smooth_dict_path) else {}
    branch_dict = torch.load(branch_dict_path, map_location=map_location) if os.path.exists(branch_dict_path) else {}
    converted_state_dict, other_state_dict = convert_to_nunchaku_flux_state_dicts(
        state_dict=state_dict,
        scale_dict=scale_dict,
        smooth_dict=smooth_dict,
        branch_dict=branch_dict,
        float_point=args.float_point,
    )
    output_dirpath = os.path.join(args.output_root, model_name)
    os.makedirs(output_dirpath, exist_ok=True)
    safetensors.torch.save_file(converted_state_dict, os.path.join(output_dirpath, "transformer_blocks.safetensors"))
    safetensors.torch.save_file(other_state_dict, os.path.join(output_dirpath, "unquantized_layers.safetensors"))
    print(f"Quantized model saved to {output_dirpath}.")
================================================
FILE: deepcompressor/backend/nunchaku/convert_lora.py
================================================
"""Convert LoRA weights to Nunchaku format."""
import argparse
import os
import safetensors
import safetensors.torch
import torch
import tqdm
from ..utils import load_state_dict_in_safetensors, pad
from .convert import update_state_dict
from .utils import NunchakuWeightPacker
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
    """Interleave the rows of an AdaNorm LoRA up-projection across its splits.

    The (channels, rank) matrix is treated as ``splits`` stacked groups; rows are
    reordered so that corresponding rows of each group become adjacent.
    """
    out_channels, rank = lora_up.shape
    assert out_channels % splits == 0
    group = out_channels // splits
    interleaved = lora_up.view(splits, group, rank).transpose(0, 1)
    return interleaved.reshape(out_channels, rank).contiguous()
def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    local_name_map: dict[str, str | list[str]],
    convert_map: dict[str, str],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Merge extra (user) LoRA weights with a block's original low-rank branch.

    For each converted module, unpacks the original Nunchaku LoRA (if any),
    fuses the extra Diffusers-format LoRAs of the candidate modules (if any),
    concatenates both along the rank dimension, and re-packs the result.

    Args:
        orig_state_dict: converted Nunchaku block state dict (packed LoRA).
        extra_lora_dict: Diffusers-format LoRA weights (``lora_A``/``lora_B``).
        converted_block_name: block prefix inside ``orig_state_dict``.
        candidate_block_name: block prefix inside ``extra_lora_dict``.
        local_name_map: converted local name -> candidate local name(s).
        convert_map: converted local name -> conversion kind
            ("adanorm_single", "adanorm_zero" or "linear").
        default_dtype: dtype used when only the extra LoRA is present.

    Returns:
        Converted LoRA state dict with keys local to the block.
    """
    print(f"Converting LoRA branch for block {candidate_block_name}...")
    converted: dict[str, torch.Tensor] = {}
    packer = NunchakuWeightPacker(bits=4)
    for converted_local_name, candidate_local_names in tqdm.tqdm(
        local_name_map.items(), desc=f"Converting {candidate_block_name}", dynamic_ncols=True
    ):
        if isinstance(candidate_local_names, str):
            candidate_local_names = [candidate_local_names]
        # region original LoRA
        orig_lora = (
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_down", None),
            orig_state_dict.get(f"{converted_block_name}.{converted_local_name}.lora_up", None),
        )
        if orig_lora[0] is None or orig_lora[1] is None:
            # lora_down and lora_up must be present or absent together.
            assert orig_lora[0] is None and orig_lora[1] is None
            orig_lora = None
        else:
            assert orig_lora[0] is not None and orig_lora[1] is not None
            orig_lora = (
                packer.unpack_lowrank_weight(orig_lora[0], down=True),
                packer.unpack_lowrank_weight(orig_lora[1], down=False),
            )
            print(f" - Found {converted_block_name} LoRA of {converted_local_name} (rank: {orig_lora[0].shape[0]})")
        # endregion
        # region extra LoRA
        extra_lora = [
            (
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_A.weight", None),
                extra_lora_dict.get(f"{candidate_block_name}.{candidate_local_name}.lora_B.weight", None),
            )
            for candidate_local_name in candidate_local_names
        ]
        if any(lora[0] is not None or lora[1] is not None for lora in extra_lora):
            # merge extra LoRAs into one LoRA
            if len(extra_lora) > 1:
                # Use the first available LoRA as a template for missing ones
                # (zero lora_up => no contribution from the placeholder).
                first_lora = None
                for lora in extra_lora:
                    if lora[0] is not None:
                        assert lora[1] is not None
                        first_lora = lora
                        break
                assert first_lora is not None
                for lora_index in range(len(extra_lora)):
                    if extra_lora[lora_index][0] is None:
                        assert extra_lora[lora_index][1] is None
                        extra_lora[lora_index] = (first_lora[0].clone(), torch.zeros_like(first_lora[1]))
                if all(lora[0].equal(extra_lora[0][0]) for lora in extra_lora):
                    # if all extra LoRAs have the same lora_down, use it
                    extra_lora_down = extra_lora[0][0]
                    extra_lora_up = torch.cat([lora[1] for lora in extra_lora], dim=0)
                else:
                    # Stack lora_downs and build a block-diagonal lora_up.
                    extra_lora_down = torch.cat([lora[0] for lora in extra_lora], dim=0)
                    extra_lora_up_c = sum(lora[1].shape[0] for lora in extra_lora)
                    extra_lora_up_r = sum(lora[1].shape[1] for lora in extra_lora)
                    assert extra_lora_up_r == extra_lora_down.shape[0]
                    extra_lora_up = torch.zeros((extra_lora_up_c, extra_lora_up_r), dtype=extra_lora_down.dtype)
                    c, r = 0, 0
                    for lora in extra_lora:
                        c_next, r_next = c + lora[1].shape[0], r + lora[1].shape[1]
                        extra_lora_up[c:c_next, r:r_next] = lora[1]
                        c, r = c_next, r_next
            else:
                extra_lora_down, extra_lora_up = extra_lora[0]
            extra_lora: tuple[torch.Tensor, torch.Tensor] = (extra_lora_down, extra_lora_up)
            print(f" - Found {candidate_block_name} LoRA of {candidate_local_names} (rank: {extra_lora[0].shape[0]})")
        else:
            # Fix: previously `extra_lora` was left as a list of (None, None)
            # pairs when no extra LoRA existed for this module, so the merge
            # step below crashed (`.to()` on a tuple / `torch.cat` with None).
            # Mark it absent so the original LoRA (or none) is used.
            extra_lora = None
        # endregion
        # region merge LoRA
        if orig_lora is None:
            if extra_lora is None:
                lora = None
            else:
                print(" - Using extra LoRA")
                lora = (extra_lora[0].to(default_dtype), extra_lora[1].to(default_dtype))
        elif extra_lora is None:
            print(" - Using original LoRA")
            lora = orig_lora
        else:
            # Concatenate along the rank dimension: down on dim 0, up on dim 1.
            lora = (
                torch.cat([orig_lora[0], extra_lora[0].to(orig_lora[0].dtype)], dim=0),
                torch.cat([orig_lora[1], extra_lora[1].to(orig_lora[1].dtype)], dim=1),
            )
            print(f" - Merging original and extra LoRA (rank: {lora[0].shape[0]})")
        # endregion
        if lora is not None:
            if convert_map[converted_local_name] == "adanorm_single":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=3), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "adanorm_zero":
                update_state_dict(
                    converted,
                    {
                        "lora_down": pad(lora[0], divisor=16, dim=0),
                        "lora_up": pad(reorder_adanorm_lora_up(lora[1], splits=6), divisor=16, dim=1),
                    },
                    prefix=converted_local_name,
                )
            elif convert_map[converted_local_name] == "linear":
                update_state_dict(
                    converted,
                    {
                        "lora_down": packer.pack_lowrank_weight(lora[0], down=True),
                        "lora_up": packer.pack_lowrank_weight(lora[1], down=False),
                    },
                    prefix=converted_local_name,
                )
    return converted
def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Convert extra LoRA weights of a FLUX single-stream block to Nunchaku format.

    Splits a fused ``proj_out`` LoRA (if present) into the two underlying
    linears, then delegates to `convert_to_nunchaku_transformer_block_lowrank_dict`.
    """
    if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
        # Diffusers fuses attn out-proj and MLP down-proj into one `proj_out`;
        # split the LoRA down-projection columns between the two Nunchaku linears
        # (the up-projection is shared by both halves).
        assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
        assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
        # qweight packs two 4-bit values per element along dim 1, so the logical
        # input-channel count is twice the packed width -- TODO confirm packing.
        n1 = orig_state_dict[f"{converted_block_name}.out_proj.qweight"].shape[1] * 2
        n2 = orig_state_dict[f"{converted_block_name}.mlp_fc2.qweight"].shape[1] * 2
        lora_down = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_A.weight"]
        lora_up = extra_lora_dict[f"{candidate_block_name}.proj_out.lora_B.weight"]
        assert lora_down.shape[1] == n1 + n2
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_A.weight"] = lora_down[:, :n1].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.0.lora_B.weight"] = lora_up.clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_A.weight"] = lora_down[:, n1:].clone()
        extra_lora_dict[f"{candidate_block_name}.proj_out.linears.1.lora_B.weight"] = lora_up.clone()
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_A.weight")
        extra_lora_dict.pop(f"{candidate_block_name}.proj_out.lora_B.weight")
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        # Nunchaku local name -> Diffusers local name(s).
        local_name_map={
            "norm.linear": "norm.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "out_proj": "proj_out.linears.0",
            "mlp_fc1": "proj_mlp",
            "mlp_fc2": "proj_out.linears.1",
        },
        convert_map={
            "norm.linear": "adanorm_single",
            "qkv_proj": "linear",
            "out_proj": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
        },
        default_dtype=default_dtype,
    )
def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    converted_block_name: str,
    candidate_block_name: str,
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Convert extra LoRA weights of a FLUX dual-stream block to Nunchaku format.

    Supplies the Diffusers->Nunchaku module-name maps for both image and
    context streams, then delegates to
    `convert_to_nunchaku_transformer_block_lowrank_dict`.
    """
    return convert_to_nunchaku_transformer_block_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        converted_block_name=converted_block_name,
        candidate_block_name=candidate_block_name,
        # Nunchaku local name -> Diffusers local name(s); q/k/v fuse into one qkv.
        local_name_map={
            "norm1.linear": "norm1.linear",
            "norm1_context.linear": "norm1_context.linear",
            "qkv_proj": ["attn.to_q", "attn.to_k", "attn.to_v"],
            "qkv_proj_context": ["attn.add_q_proj", "attn.add_k_proj", "attn.add_v_proj"],
            "norm_q": "attn.norm_q",
            "norm_k": "attn.norm_k",
            "norm_added_q": "attn.norm_added_q",
            "norm_added_k": "attn.norm_added_k",
            "out_proj": "attn.to_out.0",
            "out_proj_context": "attn.to_add_out",
            "mlp_fc1": "ff.net.0.proj",
            "mlp_fc2": "ff.net.2",
            "mlp_context_fc1": "ff_context.net.0.proj",
            "mlp_context_fc2": "ff_context.net.2",
        },
        convert_map={
            "norm1.linear": "adanorm_zero",
            "norm1_context.linear": "adanorm_zero",
            "qkv_proj": "linear",
            "qkv_proj_context": "linear",
            "out_proj": "linear",
            "out_proj_context": "linear",
            "mlp_fc1": "linear",
            "mlp_fc2": "linear",
            "mlp_context_fc1": "linear",
            "mlp_context_fc2": "linear",
        },
        default_dtype=default_dtype,
    )
def convert_to_nunchaku_flux_lowrank_dict(
    orig_state_dict: dict[str, torch.Tensor],
    extra_lora_dict: dict[str, torch.Tensor],
    default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    """Convert LoRA weights for every FLUX transformer block to Nunchaku format."""
    prefixes = ("transformer_blocks.", "single_transformer_blocks.")
    # Collect the "<kind>.<index>" prefix of every block present in the checkpoint.
    block_names: set[str] = {
        ".".join(param_name.split(".")[:2])
        for param_name in orig_state_dict
        if param_name.startswith(prefixes)
    }
    # Sort by block kind, then numerically by block index.
    block_names = sorted(block_names, key=lambda x: (x.split(".")[0], int(x.split(".")[-1])))
    print(f"Converting {len(block_names)} transformer blocks...")
    converted: dict[str, torch.Tensor] = {}
    for block_name in block_names:
        convert_fn = (
            convert_to_nunchaku_flux_transformer_block_lowrank_dict
            if block_name.startswith("transformer_blocks")
            else convert_to_nunchaku_flux_single_transformer_block_lowrank_dict
        )
        update_state_dict(
            converted,
            convert_fn(
                orig_state_dict=orig_state_dict,
                extra_lora_dict=extra_lora_dict,
                converted_block_name=block_name,
                candidate_block_name=block_name,
                default_dtype=default_dtype,
            ),
            prefix=block_name,
        )
    return converted
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantized model safetensor file")
    parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
    parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
    parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16"],
        help="default data type of the converted LoRA weights",
    )
    args = parser.parse_args()
    if not args.output_root:
        # output to the parent directory of the quantized model safetensor file
        args.output_root = os.path.dirname(args.quant_path)
    lora_name = args.lora_name
    if lora_name is None:
        # Derive the LoRA name from the file name of the LoRA checkpoint.
        assert args.lora_path is not None, "LoRA name or path must be provided"
        lora_name = args.lora_path.rstrip(os.sep).split(os.sep)[-1].replace(".safetensors", "")
        print(f"Lora name not provided, using {lora_name} as the LoRA name")
    assert lora_name, "LoRA name must be provided."
    assert args.quant_path.endswith(".safetensors"), "Quantized model must be a safetensor file"
    assert args.lora_path.endswith(".safetensors"), "LoRA weights must be a safetensor file"
    orig_state_dict = load_state_dict_in_safetensors(args.quant_path)
    # Only the transformer's LoRA tensors are relevant; strip the prefix.
    extra_lora_dict = load_state_dict_in_safetensors(args.lora_path, filter_prefix="transformer.")
    target_dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
    converted = convert_to_nunchaku_flux_lowrank_dict(
        orig_state_dict=orig_state_dict,
        extra_lora_dict=extra_lora_dict,
        default_dtype=target_dtype,
    )
    os.makedirs(args.output_root, exist_ok=True)
    safetensors.torch.save_file(converted, os.path.join(args.output_root, f"{lora_name}.safetensors"))
    print(f"Saved LoRA weights to {args.output_root}.")
================================================
FILE: deepcompressor/backend/nunchaku/utils.py
================================================
# -*- coding: utf-8 -*-
"""Nunchaku backend utilities."""
import torch
from ..tinychat.utils import convert_to_tinychat_w4x16y16_linear_weight
from ..utils import MmaWeightPackerBase, ceil_divide, fp_quantize, pad
__all__ = [
"convert_to_nunchaku_w4x4y16_linear_weight",
"convert_to_nunchaku_w8x8y16_linear_weight",
"convert_to_nunchaku_w4x16_linear_weight",
]
class NunchakuWeightPacker(MmaWeightPackerBase):
    """Pack quantized weights, scales, and low-rank branch weights into the tiled
    memory layout consumed by the Nunchaku CUDA kernels.

    The exact reshape/permute orders below mirror the warp-level fragment layout of
    the PTX ``mma`` instructions the kernels issue; do not reorder them.
    """

    def __init__(self, bits: int, warp_n: int = 128):
        """Initialize the packer.

        Args:
            bits (`int`):
                weight quantization bit-width (4 or 8 are supported by `pack_weight`).
            warp_n (`int`, *optional*, defaults to `128`):
                number of output channels handled by one warp.
        """
        super().__init__(bits=bits, warp_n=warp_n)
        # Nunchaku unrolls the `k` loop twice, so padded `k` must cover 2 memory tiles.
        self.num_k_unrolls = 2

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Pack an int32 quantized weight of shape `(n, k)` into Nunchaku's tiled layout.

        Args:
            weight (`torch.Tensor`):
                quantized weight tensor (torch.int32) of shape `(n, k)`; values must
                already fit in `bits` bits.

        Returns:
            `torch.Tensor`:
                bit-packed weight viewed as torch.int8 with shape `(n, -1)`.
        """
        assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
        n, k = weight.shape
        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
        # currently, Nunchaku does not check the boundary of the unrolled `k` dimension
        assert k % (self.mem_k * self.num_k_unrolls) == 0, (
            f"input channel size ({k}) should be divisible by "
            f"mem_k ({self.mem_k}) * num_k_unrolls ({self.num_k_unrolls})."
        )
        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
        weight = weight.reshape(
            n_tiles,
            self.num_n_packs,  # 8 when warp_n = 128
            self.n_pack_size,  # always 2 in nunchaku
            self.num_n_lanes,  # constant 8
            self.reg_n,  # constant 1
            k_tiles,
            self.num_k_packs,  # 1
            self.k_pack_size,  # always 2 in nunchaku
            self.num_k_lanes,  # constant 4
            self.reg_k,  # always 8 = 32 bits / 4 bits
        )
        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (n_tiles, k_tiles, num_k_packs, num_n_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
        assert weight.shape[4:-2] == (8, 4, 2, 2)
        if self.bits == 4:
            # fuse 8 nibbles into one int32 word (nibble 0 in the low bits)
            weight = weight.bitwise_and_(0xF)
            shift = torch.arange(0, 32, 4, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        elif self.bits == 8:
            # fuse 4 bytes into one int32 word (byte 0 in the low bits)
            weight = weight.bitwise_and_(0xFF)
            shift = torch.arange(0, 32, 8, dtype=torch.int32, device=weight.device)
            weight = weight.bitwise_left_shift_(shift)
            weight = weight.sum(dim=-1, dtype=torch.int32)
        else:
            raise NotImplementedError(f"weight bits {self.bits} is not supported.")
        return weight.view(dtype=torch.int8).view(n, -1)  # assume little-endian

    def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """Reorder scale factors into the warp-level load order used by the kernels.

        Args:
            scale (`torch.Tensor`):
                fp16/bf16 scale tensor whose leading dimension is the output-channel size.
            group_size (`int`):
                quantization group size along `k`; `-1` for per-channel / per-tensor scales.

        Returns:
            `torch.Tensor`:
                reordered scale tensor (the returned shape is only used for validation).
        """
        if self.check_if_micro_scale(group_size=group_size):
            return self.pack_micro_scale(scale, group_size=group_size)
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        n = scale.shape[0]
        # nunchaku load scales all in one access
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        #   `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 16 bit as it contains 1 fp16
        # min `s_pack_size` set to 2 element, since each lane at least holds 2 accumulator results in `n` dimension
        # max `s_pack_size` set to 128b/16b = 8 elements
        # for `warp_n = 8`, we have
        #   `s_pack_size = 2`, `num_s_lanes = 4`, `num_s_packs = 1`
        # for `warp_n = 128`, we have
        #   `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #   `s_pack_size = 8`, `num_s_lanes = 32`, `num_s_packs = 2`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 2), 8)
        num_s_lanes = min(self.num_lanes, self.warp_n // s_pack_size)
        num_s_packs = self.warp_n // (s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # `num_n_lanes = 8 (constant)` generates 8 elements consecutive in `n` dimension
        # however, they are held by 4 lanes, each lane holds 2 elements in `n` dimension
        # thus, we start from first 4 lanes, assign 2 elements to each lane, until all 8 elements are assigned
        # we then repeat the process for the same 4 lanes, until each lane holds `s_pack_size` elements
        # finally, we move to next 4 lanes, and repeat the process until all `num_s_lanes` lanes are assigned
        # the process is repeated for `num_s_packs` times
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #    0   1   8   9 <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #    2   3  10  11 <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
        #    4   5  12  13 <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
        #    6   7  14  15 <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #   16  17  24  25 <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #   ...
        #   22  23  30  31 <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
        #   ... ...
        #  112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
        #   ...
        #  118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
        scale = scale.reshape(n // warp_s, num_s_packs, num_s_lanes // 4, s_pack_size // 2, 4, 2, -1)
        scale = scale.permute(0, 6, 1, 2, 4, 3, 5).contiguous()
        return scale.view(-1) if group_size == -1 else scale.view(-1, n)  # the shape is just used for validation

    def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """Pack fp8 micro-scales (one per 16-element group) into the mma scaling layout.

        Args:
            scale (`torch.Tensor`):
                fp16/bf16 scale tensor; values must be representable in float8_e4m3fn.
            group_size (`int`):
                quantization group size; only 16 is supported.

        Returns:
            `torch.Tensor`:
                reordered fp8 scale tensor (the returned shape is only used for validation).
        """
        assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
        assert scale.max() <= 448, "scale should be less than 448."
        assert scale.min() >= -448, "scale should be greater than -448."
        assert group_size == 16, "currently only support group size 16."
        assert self.insn_k == 64, "insn_k should be 64."
        scale = scale.to(dtype=torch.float8_e4m3fn)
        n = scale.shape[0]
        assert self.warp_n >= 32, "currently only support warp_n >= 32."
        # for `[warp_n, warp_k]` weights, we load `[warp_n, warp_k / group_size]` scales
        # scale loading is parallelized in `n` dimension, that is,
        #   `num_s_lanes` in a warp load `num_s_packs` of `s_pack_size` elements, in total `warp_s` elements
        # each element in `n` dimension is 32 bit as it contains 4 fp8 in `k` dimension
        # min `s_pack_size` set to 1 element
        # max `s_pack_size` set to 128b/32b = 4 elements
        # for `warp_n = 128`, we have
        #   `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 1`
        # for `warp_n = 512`, we have
        #   `s_pack_size = 4`, `num_s_lanes = 32`, `num_s_packs = 4`
        s_pack_size = min(max(self.warp_n // self.num_lanes, 1), 4)
        num_s_lanes = 4 * 8  # 32 lanes is divided into 4 pieces, each piece has 8 lanes at a stride of 4
        num_s_packs = ceil_divide(self.warp_n, s_pack_size * num_s_lanes)
        warp_s = num_s_packs * num_s_lanes * s_pack_size
        assert warp_s == self.warp_n, "warp_n for scales should be equal to warp_n for weights."
        # note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-scaling-thread-id-b-selection
        # we start from first 8 lines at a stride of 4, assign 1 element to each lane, until all 8 elements are assigned
        # we then move to next 8 lines at a stride of 4, and repeat the process until all 32 lanes are assigned
        # here is an example for `warp_n = 128, s_pack_size = 4, num_s_lanes = 32, num_s_packs = 1`
        # wscales store order:
        #    0  32  64  96 <-- load by lane 0
        #    8  40  72 104 <-- load by lane 1
        #   16  48  80 112 <-- load by lane 2
        #   24  56  88 120 <-- load by lane 3
        #    1  33  65  97 <-- load by lane 4
        #   ...
        #   25  57  81 113 <-- load by lane 7
        #   ...
        #    7  39  71 103 <-- load by lane 28
        #   ...
        #   31  63  95 127 <-- load by lane 31
        scale = scale.view(n // warp_s, num_s_packs, s_pack_size, 4, 8, -1, self.insn_k // group_size)
        scale = scale.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
        return scale.view(-1, n)  # the shape is just used for validation

    def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """Pack Low-Rank Weight.

        Args:
            weight (`torch.Tensor`):
                low-rank weight tensor.
            down (`bool`):
                whether the weight is for down projection in low-rank branch.

        Returns:
            `torch.Tensor`:
                packed low-rank weight, flattened back to a 2D view.
        """
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        weight = pad(weight, divisor=(pack_n, pack_k), dim=(0, 1))
        # the down projection is stored transposed relative to the up projection,
        # so the tile dimensions swap roles between the two branches
        if down:
            r, c = weight.shape
            r_packs, c_packs = r // pack_n, c // pack_k
            weight = weight.view(r_packs, pack_n, c_packs, pack_k).permute(2, 0, 1, 3)
        else:
            c, r = weight.shape
            c_packs, r_packs = c // pack_n, r // pack_k
            weight = weight.view(c_packs, pack_n, r_packs, pack_k).permute(0, 2, 1, 3)
        weight = weight.reshape(
            c_packs, r_packs, self.n_pack_size, self.num_n_lanes, reg_n, self.k_pack_size, self.num_k_lanes, reg_k
        )
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        weight = weight.permute(0, 1, 3, 6, 2, 5, 4, 7).contiguous()
        return weight.view(c, r)

    def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """Unpack Low-Rank Weight (exact inverse of `pack_lowrank_weight`).

        Args:
            weight (`torch.Tensor`):
                low-rank weight tensor.
            down (`bool`):
                whether the weight is for down projection in low-rank branch.

        Returns:
            `torch.Tensor`:
                unpacked low-rank weight in its original layout.
        """
        c, r = weight.shape
        assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
        reg_n, reg_k = 1, 2  # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
        pack_n = self.n_pack_size * self.num_n_lanes * reg_n
        pack_k = self.k_pack_size * self.num_k_lanes * reg_k
        if down:
            r_packs, c_packs = r // pack_n, c // pack_k
        else:
            c_packs, r_packs = c // pack_n, r // pack_k
        weight = weight.view(
            c_packs, r_packs, self.num_n_lanes, self.num_k_lanes, self.n_pack_size, self.k_pack_size, reg_n, reg_k
        )
        # (c_packs, r_packs, num_n_lanes, num_k_lanes, n_pack_size, k_pack_size, reg_n, reg_k)
        # =>
        # (c_packs, r_packs, n_pack_size, num_n_lanes, reg_n, k_pack_size, num_k_lanes, reg_k)
        weight = weight.permute(0, 1, 4, 2, 6, 5, 3, 7).contiguous()
        weight = weight.view(c_packs, r_packs, pack_n, pack_k)
        if down:
            weight = weight.permute(1, 2, 0, 3).contiguous().view(r, c)
        else:
            weight = weight.permute(0, 2, 1, 3).contiguous().view(c, r)
        return weight

    def check_if_micro_scale(self, group_size: int) -> bool:
        # micro scales apply when one instruction `k` tile spans exactly 4 groups
        return self.insn_k == group_size * 4

    def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Pad weight so both dimensions are tile-aligned (including `k` unrolling)."""
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))

    def pad_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
        """Pad scale tensor with 1s so its dimensions are warp-aligned.

        Args:
            scale (`torch.Tensor`):
                scale tensor; grouped scales are reshaped to `(n, 1, num_groups, 1)`.
            group_size (`int`):
                quantization group size; `<= 0` means per-channel / per-tensor.
        """
        if group_size > 0 and scale.numel() > scale.shape[0]:
            scale = scale.view(scale.shape[0], 1, -1, 1)
            if self.check_if_micro_scale(group_size=group_size):
                scale = pad(scale, divisor=(self.warp_n, self.insn_k // group_size), dim=(0, 2), fill_value=1)
            else:
                scale = pad(scale, divisor=(self.warp_n, self.num_k_unrolls), dim=(0, 2), fill_value=1)
        else:
            scale = pad(scale, divisor=self.warp_n, dim=0, fill_value=1)
        return scale

    def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
        """Pad the channel dimension of a low-rank weight to a multiple of warp_n."""
        assert weight.ndim == 2, "weight tensor should be 2D."
        return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
def convert_to_nunchaku_w4x4y16_linear_weight(
    weight: torch.Tensor,
    scale: torch.Tensor,
    bias: torch.Tensor | None = None,
    smooth: torch.Tensor | None = None,
    lora: tuple[torch.Tensor, torch.Tensor] | None = None,
    float_point: bool = False,
    subscale: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    tuple[torch.Tensor, torch.Tensor] | None,
    torch.Tensor | None,
]:
    """Convert a linear weight to the Nunchaku W4-X4-Y16 packed format.

    Args:
        weight (`torch.Tensor`):
            2D weight tensor of shape `(oc, ic)` in fp16/bf16.
        scale (`torch.Tensor`):
            first-level scale; either a single element (per-tensor) or a 4D tensor
            of shape `(oc, 1, num_groups, 1)`.
        bias (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            bias tensor; zeros of shape `(oc, 1)` are used when `None`.
        smooth (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            smoothing factors; ones of shape `(ic, 1)` are used when `None`.
        lora (`tuple[torch.Tensor, torch.Tensor]` or `None`, *optional*, defaults to `None`):
            (down, up) low-rank branch weights to pack alongside the main weight.
        float_point (`bool`, *optional*, defaults to `False`):
            if `True`, quantize to 4-bit float codes in [0, 15] via `fp_quantize`;
            otherwise round to signed integers in [-8, 7].
        subscale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            second-level scale of shape `(oc, 1, num_subgroups, 1)`; subgroup size
            must divide the first-level group size.

    Returns:
        `tuple`:
            packed weight, scale, bias, smooth, lora pair (or `None`), and
            subscale (or `None`).
    """
    assert weight.ndim == 2, "weight tensor should be 2D."
    device, dtype = weight.device, weight.dtype
    assert dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
    assert scale is not None, "scale tensor is required for quantization."
    oc, ic = weight.shape
    if scale.numel() == 1:
        # broadcast a per-tensor scale to one scale per output channel
        scale = scale.view(-1).expand(oc).reshape(oc, 1, 1, 1)
        per_tensor_scale = True
    else:
        per_tensor_scale = False
    assert scale.ndim == 4, "scale tensor should be 4D."
    assert scale.shape[1] == scale.shape[3] == 1
    assert scale.shape[0] == oc
    ng, gs = scale.shape[2], ic // scale.shape[2]
    assert ic == gs * ng, "input channel size should be equal to group size times number of groups."
    if subscale is not None:
        assert subscale.ndim == 4, "subscale tensor should be 4D."
        assert subscale.shape[1] == subscale.shape[3] == 1
        assert subscale.shape[0] == oc
        nsg, sgs = subscale.shape[2], ic // subscale.shape[2]
        assert ic == sgs * nsg, "input channel size should be equal to subgroup size times number of subgroups."
        assert gs > sgs and gs % sgs == 0, "group size should be divisible by subgroup size."
    else:
        nsg, sgs = ng, gs
    # region quantize and pack weight tensor
    # divide by first-level (and optional second-level) scales in fp32 for accuracy
    weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs).div_(scale.to(dtype=torch.float32, device=device))
    if subscale is not None:
        weight = weight.view(oc, 1, nsg, sgs).div_(subscale.to(dtype=torch.float32, device=device))
    weight = weight.view(oc, ic)
    if float_point:
        weight = fp_quantize(weight)
        assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
    else:
        weight = weight.round_()
        assert weight.min() >= -8 and weight.max() <= 7, "quantized weight should be in [-8, 7]."
    # endregion
    bias = torch.zeros([oc, 1], dtype=dtype, device=device) if bias is None else bias.view(-1, 1)
    smooth = torch.ones([ic, 1], dtype=dtype, device=device) if smooth is None else smooth.view(-1, 1)
    packer = NunchakuWeightPacker(bits=4)
    weight = packer.pad_weight(weight.to(dtype=torch.int32))
    scale = packer.pad_scale(scale.to(dtype=dtype), group_size=gs)
    if subscale is not None:
        subscale = packer.pad_scale(subscale.to(dtype=dtype), group_size=sgs)
    # bias and smooth reuse the scale packing path with per-channel layout
    bias = packer.pad_scale(bias.to(dtype=dtype), group_size=-1)
    smooth = packer.pad_scale(smooth.to(dtype=dtype), group_size=-1)
    weight = packer.pack_weight(weight)
    scale = packer.pack_scale(scale, group_size=gs if gs < ic else -1)
    if subscale is not None:
        subscale = packer.pack_scale(subscale, group_size=sgs if sgs < ic else -1)
    bias = packer.pack_scale(bias, group_size=-1)
    smooth = packer.pack_scale(smooth, group_size=-1)
    if lora is not None:
        lora_down = packer.pack_lowrank_weight(packer.pad_lowrank_weight(lora[0], down=True), down=True)
        lora_up = packer.pack_lowrank_weight(packer.pad_lowrank_weight(lora[1], down=False), down=False)
        lora = (lora_down, lora_up)
    if per_tensor_scale:
        # collapse the broadcast scales back to a single element
        scale = scale.view(-1)[0].view([1])
    return weight, scale, bias, smooth, lora, subscale
def convert_to_nunchaku_w8x8y16_linear_weight(
    weight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert a linear weight to the Nunchaku W8-X8-Y16 packed format.

    Args:
        weight (`torch.Tensor`):
            2D weight tensor of shape `(oc, ic)` in fp16/bf16.
        scale (`torch.Tensor`):
            per-channel scale, or a single element for per-tensor quantization.
        bias (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            bias tensor; zeros are used when `None`.

    Returns:
        `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
            packed int8 weight, packed scale, and packed bias.
    """
    assert weight.ndim == 2, "weight tensor should be 2D."
    device, dtype = weight.device, weight.dtype
    assert dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
    assert scale is not None, "scale tensor is required for quantization."
    oc, ic = weight.shape
    # broadcast a per-tensor scale to one scale per output channel
    if scale.numel() == 1:
        scale = scale.view(-1).expand(oc)
    scale = scale.reshape(oc, 1)
    # quantize in fp32 for accuracy, then round to int8 range
    fp_weight = weight.to(dtype=torch.float32)
    fp_scale = scale.to(dtype=torch.float32, device=device)
    qweight = fp_weight.div_(fp_scale).round_().to(torch.int32).view(oc, ic)
    assert qweight.min() >= -128 and qweight.max() <= 127, "quantized weight should be in [-128, 127]."
    if bias is None:
        bias = torch.zeros([oc, 1], dtype=dtype, device=device)
    else:
        bias = bias.view(-1, 1)
    packer = NunchakuWeightPacker(bits=8)
    packed_weight = packer.pack_weight(packer.pad_weight(qweight))
    # scale and bias both use the per-channel scale packing path
    packed_scale = packer.pack_scale(packer.pad_scale(scale.to(dtype=dtype), group_size=-1), group_size=-1)
    packed_bias = packer.pack_scale(packer.pad_scale(bias.to(dtype=dtype), group_size=-1), group_size=-1)
    return packed_weight, packed_scale, packed_bias.view(-1)
def convert_to_nunchaku_w4x16_linear_weight(
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor | None = None,
    bias: torch.Tensor | None = None,
    adanorm_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert a linear weight to the Nunchaku W4-X16 (weight-only) format.

    Args:
        weight (`torch.Tensor`):
            2D weight tensor of shape `(oc, ic)`.
        scale (`torch.Tensor`):
            4D scale tensor of shape `(oc, 1, num_groups, 1)`.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            zero point tensor; when `None`, a constant 7 (uint4 mid-point) is used.
        bias (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            bias tensor; zeros are used when `None`.
        adanorm_splits (`int`, *optional*, defaults to `1`):
            number of AdaNorm output splits; when > 1, output channels are
            re-ordered from split-major to interleaved layout.

    Returns:
        `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]`:
            packed weight (viewed as int32), scale, zero point, and bias.
    """
    oc, ic = weight.shape
    assert scale.ndim == 4, "scale tensor should be 4D."
    assert scale.shape[0] == oc
    assert scale.shape[1] == scale.shape[3] == 1
    ng = scale.shape[2]
    if bias is None:
        bias = torch.zeros([oc], dtype=weight.dtype, device=weight.device)
    assert oc % adanorm_splits == 0, "output channel size should be divisible by splits."
    if adanorm_splits > 1:
        # interleave the split dimension into the output channels
        weight = weight.view(adanorm_splits, oc // adanorm_splits, ic).transpose(0, 1).reshape(oc, ic)
        scale = scale.view(adanorm_splits, oc // adanorm_splits, ng).transpose(0, 1).reshape(oc, 1, ng, 1)
        bias = bias.reshape(adanorm_splits, oc // adanorm_splits).transpose(0, 1)
        delta = [0] * adanorm_splits
        # add 1 to the bias of splits at indices 1 and -2 — presumably the AdaNorm
        # scale terms that should center around 1; TODO(review): confirm split order
        delta[1] = delta[-2] = 1
        bias = bias.add_(torch.tensor(delta, dtype=bias.dtype, device=bias.device))
        bias = bias.reshape(oc)
    # delegate quantization and 4-bit packing to the TinyChat converter
    weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
        weight=weight, scale=scale, zero=torch.full_like(scale, 7) if zero is None else zero, zero_pre_scaled=True
    )
    weight = weight.view(torch.int32)
    return weight, scale, zero, bias
================================================
FILE: deepcompressor/backend/qserve/__init__.py
================================================
================================================
FILE: deepcompressor/backend/qserve/convert.py
================================================
# -*- coding: utf-8 -*-
"""QServe state dict converter module."""
import argparse
import os
import torch
import tqdm
from .utils import convert_to_qserve_w4x8y16_linear_weight, convert_to_qserve_w8x8y16_linear_weight
__all__ = ["convert_to_qserve_state_dict"]
def convert_to_qserve_w4x8y16_linear_state_dict(
    param_name: str,
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    subscale: torch.Tensor | None = None,
    zero_pre_scaled: bool = False,
) -> dict[str, torch.Tensor]:
    """Build the QServe W4-X8-Y16 state-dict entries for one linear layer.

    Args:
        param_name (`str`):
            name of the weight parameter (must end with ``.weight``).
        weight (`torch.Tensor`):
            weight tensor to quantize and pack.
        scale (`torch.Tensor`):
            first-level scale for the weight tensor.
        zero (`torch.Tensor`):
            zero point tensor for the weight tensor.
        subscale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            second-level scale; presence selects the per-group key layout.
        zero_pre_scaled (`bool`, *optional*, defaults to `False`):
            whether the zero point tensor is pre-scaled.

    Returns:
        `dict[str, torch.Tensor]`:
            CPU tensors keyed under the owning module's name.
    """
    prefix = param_name[:-7]  # strip the trailing ".weight" to get the module name
    qweight, qscale, qzero, qsubscale = convert_to_qserve_w4x8y16_linear_weight(
        weight, scale=scale, zero=zero, subscale=subscale, zero_pre_scaled=zero_pre_scaled
    )
    converted: dict[str, torch.Tensor] = {
        f"{prefix}.qweight": qweight.cpu(),
        f"{prefix}.s1_scales": qscale.cpu(),
    }
    if qsubscale is None:
        # per-channel layout: the zero point lives at the first scale level
        converted[f"{prefix}.s1_szeros"] = qzero.cpu()
    else:
        # per-group layout: second-level scales plus their zero points
        converted[f"{prefix}.s2_scales"] = qsubscale.cpu()
        converted[f"{prefix}.s2_zeros"] = qzero.cpu()
    return converted
def convert_to_qserve_w8x8y16_linear_state_dict(
    param_name: str, weight: torch.Tensor, scale: torch.Tensor
) -> dict[str, torch.Tensor]:
    """Build the QServe W8-X8-Y16 state-dict entries for one linear layer.

    Args:
        param_name (`str`):
            name of the weight parameter (must end with ``.weight``).
        weight (`torch.Tensor`):
            weight tensor to quantize.
        scale (`torch.Tensor`):
            scale tensor for the weight tensor.

    Returns:
        `dict[str, torch.Tensor]`:
            CPU tensors keyed by ``<module>.weight`` and ``<module>.dequant_scale``.
    """
    prefix = param_name[:-7]  # strip the trailing ".weight" to get the module name
    qweight, qscale = convert_to_qserve_w8x8y16_linear_weight(weight, scale=scale)
    return {
        f"{prefix}.weight": qweight.cpu(),
        f"{prefix}.dequant_scale": qscale.cpu(),
    }
def convert_to_qserve_state_dict(
    state_dict: dict[str, torch.Tensor], scale_dict: dict[str, torch.Tensor], weight_bits: int
) -> dict[str, torch.Tensor]:
    """Convert a quantized model state dict plus its scale dict into QServe format.

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            model parameters; entries whose name appears in the scale dict are
            converted to packed QServe weights, the rest are copied through to CPU.
        scale_dict (`dict[str, torch.Tensor]`):
            quantization parameters; keys ending in ``zero``/``scaled_zero`` hold
            zero points, keys containing ``.weight.scale`` hold (multi-level) scales.
        weight_bits (`int`):
            quantized weight bit-width; must be 4 or 8.

    Returns:
        `dict[str, torch.Tensor]`:
            converted state dictionary with all tensors on CPU.
    """
    assert weight_bits in [4, 8], "weight bits should be 4 or 8."
    scales: dict[str, dict[tuple[int, ...], torch.Tensor]] = {}
    zeros: dict[str, tuple[torch.Tensor | None, bool]] = {}
    print("Loading scale tensors...")
    for name, tensor in tqdm.tqdm(scale_dict.items(), desc="Loading scale tensors", leave=False, dynamic_ncols=True):
        print(f" - Loading tensor {name} (dtype: {tensor.dtype}, shape: {tensor.shape}, device: {tensor.device})")
        if name.endswith("zero"):
            # this is a zero point tensor; an all-zero tensor means "no zero point".
            # Use a single vectorized reduction instead of a per-element Python loop
            # (each `.item()` forces a device sync on GPU tensors).
            zero = None if tensor is None or bool(tensor.eq(0).all().item()) else tensor
            if name.endswith(".scaled_zero"):
                zeros[name[:-12]] = (zero, False)  # zero point tensor is post-scaled
            else:
                zeros[name[:-5]] = (zero, True)  # zero point tensor is pre-scaled
        else:
            assert ".weight.scale" in name
            # this is a scale tensor; the trailing dotted indices encode the scale level
            idx = name.index(".weight.scale")
            param_name = name[: idx + 7]
            scale_level = tuple(map(int, name[idx + 14 :].split(".")))
            scales.setdefault(param_name, {})[scale_level] = tensor
    for param_name in zeros.keys():
        assert param_name in state_dict, f"zero point tensor {param_name} not found in state dict."
        assert param_name in scales, f"scale tensor {param_name} not found in scale dict."
    converted: dict[str, torch.Tensor] = {}
    print("Converting state dict...")
    for param_name, param in tqdm.tqdm(state_dict.items(), desc="Converting state dict", dynamic_ncols=True):
        if param_name in scales:
            print(f" - Converting {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
            weight = param.data.clone()
            if param_name in zeros:
                zero, zero_pre_scaled = zeros[param_name]
                zero = zero.clone() if zero is not None else None
            else:
                zero, zero_pre_scaled = None, False
            # sort the (at most two) scale levels so level 0 is the first-level scale
            level_scales = sorted(scales[param_name].items(), key=lambda x: x[0])
            assert len(level_scales) <= 2, "more than two scale levels are not supported."
            scale = level_scales[0][1].clone()
            subscale = level_scales[1][1].clone() if len(level_scales) > 1 else None
            if weight_bits == 4:
                converted.update(
                    convert_to_qserve_w4x8y16_linear_state_dict(
                        param_name,
                        weight,
                        scale=scale,
                        zero=zero,
                        subscale=subscale,
                        zero_pre_scaled=zero_pre_scaled,
                    )
                )
            else:
                assert zero is None, "zero point tensor is not supported for W8 quantization."
                assert subscale is None, "subscale tensor is not supported for W8 quantization."
                converted.update(convert_to_qserve_w8x8y16_linear_state_dict(param_name, weight, scale=scale))
        else:
            # non-quantized entries are copied through unchanged (tensors moved to CPU)
            if isinstance(param, torch.Tensor):
                print(f" - Copying {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
                converted[param_name] = param.clone().cpu()
            else:
                print(f" - Copying {param_name} (type: {type(param)}, value: {param})")
                converted[param_name] = param
    return converted
if __name__ == "__main__":
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantization checkpoint directory.")
    parser.add_argument("--weight-bits", type=int, required=True, help="quantized weight bits.")
    parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.")
    parser.add_argument("--model-name", type=str, default=None, help="name of the model.")
    parser.add_argument("--model-path", type=str, default=None, help="path to the huggingface model directory.")
    parser.add_argument(
        "--copy-on-save",
        action="store_true",
        help="copy the original tokenizer and configuration files to the output directory.",
    )
    args = parser.parse_args()
    if not args.output_root:
        # default to writing next to the quantization checkpoint
        args.output_root = args.quant_path
    if args.model_name is None:
        assert args.model_path is not None, "model name or path is required."
        model_name = args.model_path.rstrip(os.sep).split(os.sep)[-1]
        print(f"Model name not provided. Using model name {model_name}.")
    else:
        model_name = args.model_name
    assert model_name, "model name is required."
    model_name = f"{model_name}-w{args.weight_bits}a8"
    output_dirpath = os.path.join(args.output_root, model_name)
    output_path = os.path.join(output_dirpath, "quant_model.pt")
    state_dict = torch.load(
        os.path.join(args.quant_path, "model.pt"),
        map_location="cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu",
    )
    scale_dict = torch.load(os.path.join(args.quant_path, "scale.pt"), map_location="cpu")
    converted = convert_to_qserve_state_dict(state_dict, scale_dict, weight_bits=args.weight_bits)
    os.makedirs(output_dirpath, exist_ok=True)
    torch.save(converted, output_path)
    if args.model_path and os.path.exists(args.model_path):
        # bring the tokenizer/config files alongside the converted checkpoint
        for filename in os.listdir(args.model_path):
            if filename == "tokenizer.model" or (
                filename.endswith(".json") and filename != "pytorch_model.bin.index.json"
            ):
                filepath = os.path.abspath(os.path.join(args.model_path, filename))
                dst = os.path.join(output_dirpath, filename)
                if args.copy_on_save:
                    # shutil instead of `os.system("cp ...")`: robust to spaces and
                    # shell metacharacters in paths, and raises on failure
                    shutil.copy2(filepath, dst)
                else:
                    # refresh a stale link when the script is re-run
                    if os.path.lexists(dst):
                        os.remove(dst)
                    os.symlink(filepath, dst)
    print(f"Quantized model checkpoint saved to {output_path}.")
    print(f"Quantized model saved to {output_dirpath}.")
================================================
FILE: deepcompressor/backend/qserve/utils.py
================================================
# -*- coding: utf-8 -*-
"""QServe backend utilities."""
import torch
from ..utils import MmaWeightPackerBase
__all__ = ["convert_to_qserve_w4x8y16_linear_weight", "convert_to_qserve_w8x8y16_linear_weight"]
class QServePacker(MmaWeightPackerBase):
    """Pack quantized weights and quantization parameters into QServe's W4-A8 layout."""

    def __init__(self):
        super().__init__(bits=8, warp_n=32)
        # pack_weight fuses pairs of `n` packs (indices 0 and 1) into one byte,
        # so the pack count along `n` must be even
        assert self.num_n_packs >= 2 and self.num_n_packs % 2 == 0, (
            f"num_n_packs should be even, but got {self.num_n_packs}."
        )

    def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
        """Pack a uint4 weight of shape `(n, k)` into QServe's tiled int8 layout.

        Args:
            weight (`torch.Tensor`):
                quantized weight tensor (torch.uint8) with values in [0, 15].

        Returns:
            `torch.Tensor`:
                packed weight viewed as torch.int8 with shape `(n, k // 2)`
                (two uint4 values per byte).
        """
        assert weight.min() >= 0, "quantized weight should be non-negative."
        assert weight.max() <= 15, "quantized weight should be less than 16."
        assert weight.dtype == torch.uint8, f"quantized weight should be torch.uint8, but got {weight.dtype}."
        n, k = weight.shape
        assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
        assert k % self.mem_k == 0, f"input channel size ({k}) should be divisible by mem_k ({self.mem_k})."
        n_tiles, k_tiles = n // self.mem_n, k // self.mem_k
        weight = weight.reshape(
            n_tiles,
            self.num_n_packs,  # num_n_packs = 2 when warp_n = 32
            self.n_pack_size,  # always 2 in QServe
            self.num_n_lanes,  # constant 8
            self.reg_n,  # constant 1
            k_tiles,
            self.num_k_packs,  # constant 1
            self.k_pack_size,  # always 2
            self.num_k_lanes,  # constant 4
            self.reg_k,  # always 4 = 32 bits / 8 bits in QServe
        )
        # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
        # =>
        # (num_n_packs, n_tiles, k_tiles, num_k_packs, num_n_lanes, num_k_lanes, k_pack_size, n_pack_size, reg_n, reg_k)
        weight = weight.permute(1, 0, 5, 6, 3, 8, 7, 2, 4, 9).contiguous()
        assert weight.shape[4:-2] == (8, 4, 2, 2)
        # fuse pack 1 into the high nibble and pack 0 into the low nibble of each byte
        weight = (weight[1] << 4) + weight[0]
        return weight.view(torch.int8).view(n, k // 2)

    def pack_scale(
        self, scale: torch.Tensor, zero: torch.Tensor | None = None, subscale: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Reorder scale / zero / subscale tensors for QServe's kernel load order.

        Args:
            scale (`torch.Tensor`):
                first-level scale tensor (flattened to 1D).
            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                zero point tensor; flattened for per-channel, tiled for per-group.
            subscale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                int8 second-level scale; presence selects the per-group layout.

        Returns:
            `tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]`:
                reordered scale, zero, and subscale tensors.
        """
        scale = scale.view(-1)
        n = scale.shape[0]
        if subscale is None:
            zero = zero.view(-1)
        else:
            assert subscale.dtype == torch.int8, f"subscale should be torch.int8, but got {subscale.dtype}."
            view_shape = (n // self.mem_n, self.num_n_packs, self.n_pack_size, self.num_n_lanes, self.reg_n, -1)
            # (n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n, -1)
            # =>
            # (-1, n_tiles, num_n_packs, num_n_lanes, n_pack_size, reg_n)
            subscale = subscale.view(view_shape).permute(5, 0, 1, 3, 2, 4).contiguous().view(-1, n)
            zero = zero.view(view_shape).permute(5, 0, 1, 3, 2, 4).contiguous().view(-1, n)
        return scale, zero, subscale
def convert_to_qserve_w4x8y16_linear_weight(
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    subscale: torch.Tensor | None = None,
    zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Convert a weight tensor to QServe W4-X8-Y16 linear weight format.

    Args:
        weight (`torch.Tensor`):
            weight tensor to be converted.
        scale (`torch.Tensor`):
            scale tensor for the weight tensor.
        zero (`torch.Tensor`):
            zero point tensor for the weight tensor.
        subscale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            subscale tensor for the weight tensor; presence selects per-group
            (two-level) quantization instead of per-channel.
        zero_pre_scaled (`bool`, *optional*, defaults to `False`):
            whether zero point tensor is pre-scaled.

    Returns:
        `tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]`:
            packed quantized weight tensor, scale tensor, zero point tensor, and subscale tensor.
    """
    dtype = weight.dtype
    assert dtype == torch.float16, "currently qserve only supports fp16."
    assert scale is not None, "scale tensor is required for quantization."
    assert zero is not None, "zero point tensor is required for quantization."
    # do the arithmetic in fp32 for accuracy
    weight = weight.to(dtype=torch.float32)
    scale = scale.to(dtype=torch.float32, device=weight.device)
    zero = zero.to(dtype=torch.float32, device=weight.device)
    oc, ic = weight.shape
    if subscale is not None:  # per-group quantization
        subscale = subscale.to(dtype=weight.dtype, device=weight.device)
        # region reshape scale and zero point
        if scale.numel() == 1:
            scale = scale.view(-1).expand(oc)
        scale = scale.reshape(oc).contiguous().view(oc, 1)
        assert subscale.numel() > 1, "subscale tensor is required for per-group quantization."
        subscale = subscale.view(oc, -1, 1).round_()
        ng = subscale.shape[1]
        gs = ic // ng
        assert ic == ng * gs, "input channel size should be divisible by group size."
        if zero.numel() == 1:
            zero = zero.view(1, 1).expand(oc, ng)
        zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1).round_()
        # endregion
        # region quantize weight tensor
        # first level: int8 quantization with the per-channel scale
        weight = weight.div_(scale).round_()
        assert weight.min() >= -128, "first-level quantized weight should be greater than or equal to -128."
        assert weight.max() <= 127, "first-level quantized weight should be less than or equal to 127."
        weight = weight.view(oc, ng, gs)
        if not zero_pre_scaled:  # zero point is int8
            weight = weight.add_(zero)
        # second level: int4 quantization with the per-group subscale
        weight = weight.div_(subscale)
        if zero_pre_scaled:  # zero point is int4
            if zero.min() < 0:  # sint4 zero point
                zero = zero.add_(8)  # convert to uint4 zero point
            assert zero.min() >= 0, "quantized zero point should be non-negative."
            assert zero.max() <= 15, "quantized zero point should be less than 16."
            weight = weight.add_(zero)
            zero = zero.mul_(subscale)
        else:
            if weight.min() < 0:  # sint4 weight
                weight = weight.add_(8)  # convert to uint4 weight
                zero = zero.add_(8 * subscale)
            # sanity-check that dequantizing to the first level stays in uint8 range
            _weight = weight.mul(subscale)
            assert _weight.min() >= 0, "first-level dequantize weight should be non-negative."
            assert _weight.max() <= 255, "first-level dequantize weight should be less than 256."
            del _weight
        assert subscale.min() >= 0, "subscale should be non-negative."
        assert subscale.max() <= 127, "subscale should be less than or equal to 127."
        assert zero.min() >= 0, "quantized zero point should be non-negative."
        assert zero.max() <= 255, "quantized zero point should be less than 256."
        assert weight.min() >= 0, "quantized weight should be non-negative."
        assert weight.max() <= 15, "quantized weight should be less than 16."
        # endregion
        zero = -zero  # ! for group quant, qserve uses q*s+z=r instead of q*s-z=r
        # NOTE(review): zero may hold values up to 255 before negation; the int8 cast
        # below assumes magnitudes actually fit in int8 — confirm against the kernel
        subscale = subscale.to(torch.int8)
        zero = zero.to(torch.int8)
    else:  # per-channel quantization
        assert subscale is None, "subscale tensor is not required for per-channel quantization."
        # region reshape scale and zero point
        if scale.numel() == 1:
            scale = scale.view(-1).expand(oc)
        scale = scale.reshape(oc).contiguous().view(oc, 1)
        if zero.numel() == 1:
            zero = zero.view(-1).expand(oc)
        zero = zero.reshape(oc).contiguous().view(oc, 1)
        # endregion
        # region quantize weight tensor
        if not zero_pre_scaled:  # zero point is fp16
            weight = weight.add_(zero)
        weight = weight.div_(scale).round_()
        if zero_pre_scaled:  # zero point is int4
            zero = zero.round_()
            if zero.min() < 0:  # sint4 zero point
                zero = zero.add_(8)  # convert to uint4 zero point
            assert zero.min() >= 0, "quantized zero point should be non-negative."
            assert zero.max() <= 15, "quantized zero point should be less than 16."
            weight = weight.add_(zero)
            zero = zero.mul_(scale)
        else:
            if weight.min() < 0:  # sint4 weight
                weight = weight.add_(8)  # convert to uint4 weight
                zero = zero.add_(8 * scale)
        assert weight.min() >= 0, "quantized weight should be non-negative."
        assert weight.max() <= 15, "quantized weight should be less than 16."
        # endregion
        zero = zero.to(dtype=dtype)
    scale = scale.to(dtype=dtype)
    packer = QServePacker()
    weight = packer.pack_weight(weight.view(oc, ic).to(torch.uint8))
    scale, zero, subscale = packer.pack_scale(scale=scale, zero=zero, subscale=subscale)
    return weight, scale, zero, subscale
def convert_to_qserve_w8x8y16_linear_weight(
    weight: torch.Tensor, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert a weight tensor to QServe W8-X8-Y16 linear weight format.

    Quantizes per output channel: each fp16 weight row is divided by its scale
    and rounded to a signed 8-bit integer.

    Args:
        weight (`torch.Tensor`):
            weight tensor to be converted.
        scale (`torch.Tensor`):
            scale tensor for the weight tensor.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`:
            packed quantized weight tensor and scale tensor.
    """
    orig_dtype = weight.dtype
    assert orig_dtype == torch.float16, "currently qserve only supports fp16."
    assert scale is not None, "scale tensor is required for quantization."
    out_channels = weight.shape[0]
    # Work in fp32 so the in-place div/round below does not touch the caller's tensor.
    fp_weight = weight.to(dtype=torch.float32)
    fp_scale = scale.to(dtype=torch.float32, device=fp_weight.device)
    # Broadcast a scalar scale to one entry per output channel, then shape (oc, 1).
    if fp_scale.numel() == 1:
        fp_scale = fp_scale.view(-1).expand(out_channels)
    fp_scale = fp_scale.reshape(out_channels).contiguous().view(out_channels, 1)
    quantized = fp_weight.div_(fp_scale).round_()
    assert quantized.min() >= -128, "quantized weight should be greater than or equal to -128."
    assert quantized.max() <= 127, "quantized weight should be less than or equal to 127."
    packed_weight = quantized.contiguous().to(torch.int8)
    packed_scale = fp_scale.view(out_channels).to(dtype=orig_dtype)
    return packed_weight, packed_scale
================================================
FILE: deepcompressor/backend/tinychat/__init__.py
================================================
================================================
FILE: deepcompressor/backend/tinychat/convert.py
================================================
# -*- coding: utf-8 -*-
"""QServe state dict converter module."""
import argparse
import os
import safetensors.torch
import torch
import tqdm
from .utils import convert_to_tinychat_w4x16y16_linear_weight
def convert_to_tinychat_w4x16y16_linear_state_dict(
    param_name: str,
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    zero_pre_scaled: bool = False,
) -> dict[str, torch.Tensor]:
    """Convert a weight tensor to TinyChat W4-X16-Y16 linear state dictionary.

    Args:
        param_name (`str`):
            parameter name (expected to end with ".weight").
        weight (`torch.Tensor`):
            weight tensor to be converted.
        scale (`torch.Tensor`):
            scale tensor for the weight tensor.
        zero (`torch.Tensor`):
            zero point tensor for the weight tensor.
        zero_pre_scaled (`bool`, *optional*, defaults to `False`):
            whether zero point tensor is pre-scaled.

    Returns:
        `dict[str, torch.Tensor]`:
            state dictionary for the quantized weight tensor.
    """
    # Strip the trailing ".weight" to get the owning module's name.
    module_name = param_name[:-7]
    qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
        weight, scale=scale, zero=zero, zero_pre_scaled=zero_pre_scaled
    )
    return {
        f"{module_name}.qweight": qweight.cpu(),
        f"{module_name}.scales": qscale.cpu(),
        f"{module_name}.scaled_zeros": qzero.cpu(),
    }
def convert_to_tinychat_state_dict(
    state_dict: dict[str, torch.Tensor], scale_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Convert a quantized checkpoint to a TinyChat W4-X16-Y16 state dictionary.

    Parameters that have an entry in ``scale_dict`` are packed via
    ``convert_to_tinychat_w4x16y16_linear_state_dict``; everything else is
    copied through unchanged (tensors moved to CPU).

    Args:
        state_dict (`dict[str, torch.Tensor]`):
            original model state dictionary.
        scale_dict (`dict[str, torch.Tensor]`):
            quantization tensors keyed by parameter name plus a suffix:
            "<param>.weight.scale.<level...>" for scales, and
            "<param>.scaled_zero" (post-scaled) or "<param>...zero"
            (pre-scaled) for zero points.

    Returns:
        `dict[str, torch.Tensor]`:
            converted state dictionary.
    """
    # param name -> {scale-level tuple -> scale tensor}
    scales: dict[str, dict[tuple[int, ...], torch.Tensor]] = {}
    # param name -> (zero tensor or None, whether the zero point is pre-scaled)
    zeros: dict[str, tuple[torch.Tensor | None, bool]] = {}
    print("Loading scale tensors...")
    for name, tensor in tqdm.tqdm(scale_dict.items(), desc="Loading scale tensors", leave=False, dynamic_ncols=True):
        print(f" - Loading tensor {name} (dtype: {tensor.dtype}, shape: {tensor.shape}, device: {tensor.device})")
        if name.endswith("zero"):
            # this is a zero point tensor
            # An all-zero tensor is recorded as None so downstream code can skip it.
            zero = None if tensor is None or all(t.item() == 0 for t in tensor.flatten()) else tensor
            if name.endswith(".scaled_zero"):
                # len(".scaled_zero") == 12
                zeros[name[:-12]] = (zero, False)  # zero point tensor is post-scaled
            else:
                # len(".zero") == 5
                zeros[name[:-5]] = (zero, True)  # zero point tensor is pre-scaled
        else:
            assert ".weight.scale" in name
            # this is a scale tensor
            idx = name.index(".weight.scale")
            # Keep the ".weight" suffix (7 chars) in the parameter name.
            param_name = name[: idx + 7]
            # Suffix after ".weight.scale." parsed as a tuple of ints, e.g. "0.1" -> (0, 1).
            # NOTE(review): assumes at least one numeric level follows — a bare
            # "...weight.scale" key would raise on int("") — confirm against the dumper.
            scale_level = tuple(map(int, name[idx + 14 :].split(".")))
            scales.setdefault(param_name, {})[scale_level] = tensor
    # Every zero point must correspond to a real parameter and a scale tensor.
    for param_name in zeros.keys():
        assert param_name in state_dict, f"zero point tensor {param_name} not found in state dict."
        assert param_name in scales, f"scale tensor {param_name} not found in scale dict."
    converted: dict[str, torch.Tensor] = {}
    print("Converting state dict...")
    for param_name, param in tqdm.tqdm(state_dict.items(), desc="Converting state dict", dynamic_ncols=True):
        if param_name in scales:
            # Quantized parameter: pack weight + scale (+ optional zero point).
            print(f" - Converting {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
            weight = param.data.clone()
            if param_name in zeros:
                zero, zero_pre_scaled = zeros[param_name]
                zero = zero.clone() if zero is not None else None
            else:
                zero, zero_pre_scaled = None, False
            level_scales = sorted(scales[param_name].items(), key=lambda x: x[0])
            # Only single-level quantization is supported by the TinyChat packer.
            assert len(level_scales) == 1, "more than one scale levels are not supported."
            scale = level_scales[0][1].clone()
            converted.update(
                convert_to_tinychat_w4x16y16_linear_state_dict(
                    param_name, weight, scale=scale, zero=zero, zero_pre_scaled=zero_pre_scaled
                )
            )
        else:
            # Unquantized entry: copy through (tensors are detached to CPU).
            if isinstance(param, torch.Tensor):
                print(f" - Copying {param_name} (dtype: {param.dtype}, shape: {param.shape}, device: {param.device})")
                converted[param_name] = param.clone().cpu()
            else:
                print(f" - Copying {param_name} (type: {type(param)}, value: {param})")
                converted[param_name] = param
    return converted
if __name__ == "__main__":
    # CLI entry point: convert a deepcompressor quantization checkpoint
    # (model.pt + scale.pt) into a TinyChat W4A16 checkpoint directory.
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument("--quant-path", type=str, required=True, help="path to the quantization checkpoint directory.")
    parser.add_argument("--output-root", type=str, default="", help="root to the output checkpoint directory.")
    parser.add_argument("--model-name", type=str, default=None, help="model name.")
    parser.add_argument("--model-path", type=str, default=None, help="path to the huggingface model directory.")
    parser.add_argument("--copy-on-save", action="store_true", help="copy files on save.")
    args = parser.parse_args()
    if not args.output_root:
        args.output_root = args.quant_path
    if args.model_name is None:
        assert args.model_path is not None, "model name or path is required."
        # Derive the model name from the last path component of --model-path.
        model_name = args.model_path.rstrip(os.sep).split(os.sep)[-1]
        print(f"Model name not provided. Using model name {model_name}.")
    else:
        model_name = args.model_name
    state_dict = torch.load(
        os.path.join(args.quant_path, "model.pt"),
        map_location="cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu",
    )
    scale_dict = torch.load(os.path.join(args.quant_path, "scale.pt"), map_location="cpu")
    converted = convert_to_tinychat_state_dict(state_dict, scale_dict)
    # Bug fix: use the resolved ``model_name`` here. The original formatted
    # ``args.model_name``, which is None when the name is derived from
    # --model-path, producing a "None-w4a16" output directory.
    model_name = f"{model_name}-w4a16"
    output_dirpath = os.path.join(args.output_root, model_name)
    os.makedirs(output_dirpath, exist_ok=True)
    if args.model_path and os.path.exists(args.model_path):
        output_path = os.path.join(output_dirpath, "model.safetensors")
        safetensors.torch.save_file(converted, output_path)
        print(f"Quantized model checkpoint saved to {output_path}.")
        # Bring over the tokenizer/config files so the output directory is a
        # self-contained huggingface-style checkpoint.
        for filename in os.listdir(args.model_path):
            if filename == "tokenizer.model" or (
                filename.endswith(".json") and filename != "pytorch_model.bin.index.json"
            ):
                filepath = os.path.abspath(os.path.join(args.model_path, filename))
                dst = os.path.join(output_dirpath, filename)
                # Use shutil/os primitives instead of shelling out to cp/ln:
                # safe for paths containing spaces or shell metacharacters,
                # and the original ln -s command had a mangled destination.
                if args.copy_on_save:
                    shutil.copy2(filepath, dst)
                else:
                    if os.path.lexists(dst):
                        os.remove(dst)  # make re-runs idempotent (ln -s would fail)
                    os.symlink(filepath, dst)
    else:
        output_path = os.path.join(output_dirpath, "tinychat-v2.pt")
        torch.save(converted, output_path)
        print(f"Quantized model checkpoint saved to {output_path}.")
    print(f"Quantized model saved to {output_dirpath}.")
================================================
FILE: deepcompressor/backend/tinychat/csrc/load.py
================================================
# -*- coding: utf-8 -*-
"""TinyChat Extension.

Importing this module JIT-compiles the TinyChat CUDA kernels (GEMM + GEMV)
via ``torch.utils.cpp_extension.load`` and exposes them as ``_C``.
"""

import os

from torch.utils.cpp_extension import load

__all__ = ["_C"]

# Sources live next to this file; resolve them relative to the module.
dirpath = os.path.dirname(__file__)

# NOTE(review): this triggers an nvcc build on first import (cached by torch
# afterwards), so importing this module requires a CUDA toolchain.
_C = load(
    name="deepcompressor_tinychat_C",
    sources=[
        f"{dirpath}/pybind.cpp",
        f"{dirpath}/quantization/gemv/gemv_cuda.cu",
        f"{dirpath}/quantization/gemm/gemm_cuda.cu",
    ],
    extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"],
    extra_cuda_cflags=[
        "-O3",
        "-std=c++20",
        # Re-enable the half/bfloat16 operator overloads that torch disables
        # by default in extension builds.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "--ptxas-options=--allow-expensive-optimizations=true",
        "--threads=8",
    ],
)
================================================
FILE: deepcompressor/backend/tinychat/csrc/pybind.cpp
================================================
// NOTE(review): the original #include targets were stripped during extraction
// (bare "#include" lines); restored to the standard torch-extension headers —
// confirm against the upstream sources.
#include <pybind11/pybind11.h>
#include <torch/extension.h>

#include "quantization/gemm/gemm_cuda.h"
#include "quantization/gemv/gemv_cuda.h"

// Python bindings for the TinyChat quantized GEMM/GEMV CUDA kernels.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
    m.def("awq_gemv_forward_cuda", &awq_gemv_forward_cuda, "AWQ quantized GEMV kernel.");
}
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/dequantize.cuh
================================================
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once

// NOTE(review): the #include targets were stripped during extraction; restored
// to the CUDA half/bfloat16 headers required by the half2/__nv_bfloat162
// specializations below — confirm against the upstream sources.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Dequantize 8 packed 4-bit unsigned integers held in `source` into 8
// half-precision values written to `result` (four 16-bit pairs in a uint4).
// NOTE(review): the template header was stripped during extraction; restored
// as the single type parameter used by the specializations.
template <typename T>
__device__ __forceinline__ void dequantize_s4_to_f16x2(T const &source, uint4 *result);
// fp16 specialization: converts 8 packed uint4 values to 8 half values using
// lop3-based bit tricks (no integer-to-float conversion instructions).
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2(half2 const &source, uint4 *result)
{
    // NOTE(review): the reinterpret_cast template arguments were stripped
    // during extraction; restored to the standard FasterTransformer/AWQ form.
    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
    // elt_67 to fp16 without having to shift them to the bottom bits before hand.
    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
    // immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
    // half2 ctor. In this case, I chose performance reliability over code readability.
    // This is the half2 {1032, 1032} represented as an integer.
    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-72, -72} represented as an integer.
    // static constexpr uint32_t NEG_72 = 0xd480d480;
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;
    // Finally, we construct the output numbers.
    // Convert elt_01
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
// bf16 specialization: converts 8 packed uint4 values to 8 __nv_bfloat16
// values; masks 4 bits at a time, shifting the source between extractions.
template <>
__device__ __forceinline__ void dequantize_s4_to_f16x2<__nv_bfloat162>(__nv_bfloat162 const &source, uint4 *result)
{
    // NOTE(review): the reinterpret_cast template arguments were stripped
    // during extraction; restored to match the half2 specialization.
    uint32_t *h = reinterpret_cast<uint32_t *>(result);
    uint32_t const source_i4s = reinterpret_cast<uint32_t const &>(source);
    // First, we extract the i4s and construct an intermediate bf16 number.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t MASK = 0x000f000f;
    static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
    uint32_t i4s = source_i4s;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x43004300
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x43004300
    i4s >>= 4;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x43004300
    i4s >>= 4;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x43004300
    i4s >>= 4;
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    // This is the BF16 {-136, -136} represented as an integer.
    // static constexpr uint32_t BF16_BIAS = 0xC308C308;
    // This is the BF16 {-128, -128} represented as an integer, we do not need to map to [-8, 7]
    static constexpr uint32_t NEG_128 = 0xC300C300;
    static constexpr uint32_t ONE = 0x3F803F80;
    // Finally, we construct the output numbers.
    // Convert elt_01
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(ONE), "r"(NEG_128));
    // Convert elt_23
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE), "r"(NEG_128));
    // Convert elt_45
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(ONE), "r"(NEG_128));
    // Convert elt_67
    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE), "r"(NEG_128));
}
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/gemm/gemm_cuda.cu
================================================
// NOTE(review): every angle-bracket #include target in this file was stripped
// during extraction; the headers below are reconstructed from usage
// (torch::Tensor/torch::empty, half/__nv_bfloat16, __pipeline_commit/
// __pipeline_wait_prior, printf) — confirm against the upstream sources.
#include <torch/extension.h>
#include <cuda_fp16.h>
#include "semaphore.h"
#include "gemm_cuda.h"
#include "../dequantize.cuh"
#include "../../utils.cuh"
#include <cuda_bf16.h>
#include <cuda_pipeline_primitives.h>

// Tile/fragment geometry shared by the W4A16 GEMM kernel and its launchers.
#define kInterleave 4
#define OP_M 16
#define OP_N 8
#define OP_K 16
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 16
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 8

// cp.async L2 cache-hint suffix is available from CUDA 11.4 onward.
// Fixed: the original `major >= 11 && minor >= 4` test wrongly disabled the
// hint on CUDA 12.0-12.3 (where minor < 4).
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
// Shared launch sequence for the W4A16 GEMM: allocates split-K semaphores,
// computes the dynamic shared-memory budget (bails out past the ~99KB device
// limit), and launches gemm_w4a16_T1 over an M x N grid scaled by SPLITK.
// Expects num_in_feats/num_out_feats/num_out_channels/num_in_channels,
// options_int, in_feats/kernel/scales/zeros/out_feats/_out_feats and the
// CTA_*/WARP_*/STAGES/G/SPLITK constants to be in scope at the expansion site.
// NOTE(review): extraction stripped angle-bracket text from this macro — the
// reinterpret_cast<...> argument (likely `int *`), the template-argument list
// on `gemm_w4a16_T1`, and the `<<<num_blocks, threads_per_block,
// kSmemByteSize>>>` launch configuration are missing and must be restored
// from the upstream sources before this can compile.
// (Comments cannot be placed inside the macro body: a `\` continuation at the
// end of a // comment would splice the next line into the comment.)
#define KERNEL_LAUNCH_CODE \
    int num_mn_tiles = (num_in_feats + CTA_M - 1) / CTA_M * (num_out_channels + CTA_N - 1) / CTA_N; \
    torch::Tensor _semaphores = torch::empty({num_mn_tiles}, options_int); \
    auto semaphores = reinterpret_cast(_semaphores.data_ptr()); \
    constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
    constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N / (G / CTA_K) * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
    constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + SCALES_SMEM_SIZE) * STAGES * sizeof(f16_t); \
    if (kSmemByteSize >= 99 * 1024) \
    { \
        printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize); \
        return _out_feats; \
    } \
    int j_factors1 = num_out_channels / CTA_N / 1; \
    dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
    dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
    auto kernel_func = gemm_w4a16_T1; \
    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
    kernel_func<<>>( \
        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
// Pick the log2 of the block-rasterization tile width given the number of
// column tiles `n`; capped by the compile-time maximum N (the visible call
// site uses get_log_tile<1>, confirming a single int template parameter).
// NOTE(review): the template header was stripped during extraction; restored
// as `template <int N>` from that call site.
template <int N>
__inline__ __host__ __device__ int get_log_tile(int n)
{
    if (N >= 8 && n >= 6)
        return 3;
    else if (N >= 4 && n >= 3)
        return 2;
    else if (N >= 2 && n >= 2)
        return 1;
    else
        return 0;
}
// Remap a linear (x, y) block index into the swizzled rasterization order:
// x is folded down by the tile width while its low bits index rows inside the
// current column tile, improving L2 locality across neighboring CTAs.
__inline__ __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
    const int tile_mask = (1 << log_tile) - 1;
    const int mapped_x = blockIdx_x >> log_tile;
    const int mapped_y = (blockIdx_y << log_tile) + (blockIdx_x & tile_mask);
    return make_uint2(mapped_x, mapped_y);
}
// Synchronize the threads cooperating on one K-slice: a full __syncthreads()
// when there is a single slice, otherwise a named bar.sync barrier scoped to
// the NUM_WARPS_MN * WARP_SIZE threads of the slice (barrier 0 is left for
// __syncthreads(); slice barriers start at 1).
// NOTE(review): the template parameter list (two int parameters providing
// NUM_WARPS_MN and SLICES; order unknown) was stripped during extraction —
// restore it from the upstream sources before compiling.
template
__device__ void sync_slice(int slice_id)
{
    if constexpr (SLICES == 1)
    {
        __syncthreads();
    }
    else
    {
        // Up to 8 slices share one hardware barrier id.
        constexpr int SLICE_GROUP = (SLICES + 7) / 8;
        constexpr uint32_t num_threads = NUM_WARPS_MN * WARP_SIZE;
        const uint32_t barrier_id = slice_id / SLICE_GROUP + 1;
        asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "n"(num_threads));
    }
}
// Convert a generic-address-space pointer into the 32-bit shared-memory
// address expected by ldmatrix / cp.async PTX instructions
// (cvta.to.shared + truncation to 32 bits).
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr)
{
    uint32_t smem_int_ptr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
        : "=r"(smem_int_ptr)
        : "l"(ptr));
    return smem_int_ptr;
}
// Load four 8x8 16-bit matrix fragments from shared memory into the warp's
// registers via ldmatrix; results land at shared_warp[ax0_0 * 8 ..].
// NOTE(review): the template header and the std::is_same operand lists were
// stripped during extraction; restored as `<typename f16_t>` with
// `std::is_same<f16_t, half>` / `<f16_t, __nv_bfloat16>` per the assert
// message — confirm against the upstream sources.
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    __asm__ __volatile__(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
        : "r"(addr));
}
// Same as ldmatrix_m8n8_x4_b16, but each 8x8 fragment is transposed on load
// (ldmatrix ... .trans variant).
// NOTE(review): template header and std::is_same operands restored after
// extraction stripping — confirm against the upstream sources.
template <typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr)
{
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    __asm__ __volatile__(
        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
        "{%0, %1, %2, %3}, [%4];"
        : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]), "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
        : "r"(addr));
}
// Issue one predicated 16-byte cp.async copy (global -> shared). `mask` is
// materialized as a PTX predicate, so masked-off threads issue no transfer;
// the L2_CACHEHINT suffix requests 128B L2 prefetch granularity when the
// toolkit supports it.
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
    const int cp_size = 16;
    asm volatile("{"
                 " .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"((int)mask),
                 "r"(smem_int_ptr),
                 "l"(src),
                 "n"(cp_size));
}
// Warp-synchronous m16n8k16 tensor-core MMA with fp32 accumulation:
// C_warp += A_shared_warp * B_shared_warp, fragments held in registers.
// NOTE(review): the template header was stripped during extraction; restored
// as `template <typename f16_t>` from the two specializations below.
template <typename f16_t>
__device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16_t *B_shared_warp);

// fp16 specialization (mma ... .f32.f16.f16.f32).
template <>
__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
}

// bf16 specialization (mma ... .f32.bf16.bf16.f32).
template <>
__device__ __inline__ void mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp)
{
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
}
// Copy one pipeline stage of the A (activation) tile from global memory into
// shared memory, 16 bytes per thread per iteration, with an XOR swizzle on the
// destination column to avoid shared-memory bank conflicts. Rows past
// global_nrows are masked off. Uses cp.async when STAGES > 1, otherwise a
// plain vectorized store.
// NOTE(review): the template parameter list (type f16_t plus the CTA_M/CTA_N/
// CTA_K/CTA_SIZE/STAGES/SHARED_K_ITERS constants used below; order unknown)
// was stripped during extraction — restore from upstream before compiling.
template
__device__ __inline__ void global_to_share_one_stage_A(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    // Threads beyond threads_used contribute nothing this stage.
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        // XOR-swizzle the packed column index with the row's low 3 bits.
        int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
        void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
        uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
        }
        else
        {
            if (local_mask & (ld_row + cta_offset_m < global_nrows))
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
// Copy one pipeline stage of the packed B (weight) tile from global memory
// into shared memory. B is stored interleaved (kInterleave rows packed
// together), so only CTA_N / kInterleave rows are transferred; columns get a
// row-parity XOR swizzle. Uses cp.async when STAGES > 1.
// NOTE(review): the template parameter list (type f16_t plus CTA_N/CTA_K/
// CTA_SIZE/STAGES/SHARED_K_ITERS; order unknown) was stripped during
// extraction — restore from upstream before compiling.
template
__device__ __inline__ void global_to_share_one_stage_B(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
    constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
    constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
    constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
    constexpr int threads_per_row = CTA_K / PACK_SIZE;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
    for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
    {
        int global_iter = shared_iter_k * partial_global_iters + _global_iter;
        int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
        int ld_col = (threadIdx.x % threads_per_row);
        // Swizzle by row parity only (B rows alternate between two layouts).
        int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
        void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
        uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
        if constexpr (STAGES > 1)
        {
            uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
            cp_async_cg_A(addr, src_ptr, local_mask);
        }
        else
        {
            if (local_mask)
                *(uint4 *)dst_ptr = *src_ptr;
        }
    }
}
// Copy one stage of the per-group dequantization scales and zero points from
// global memory into shared memory. The amount loaded depends on whether a
// quantization group (G) spans one or several CTA_K tiles. Scales and zeros
// are loaded in lockstep with identical addressing.
// NOTE(review): the template parameter list (type f16_t plus CTA_N/CTA_K/
// CTA_SIZE/STAGES/G; order unknown) was stripped during extraction — restore
// from upstream. Also note the plain `if (STAGES > 1)` below (not
// `if constexpr` as in the sibling loaders); harmless since STAGES is a
// compile-time constant, but inconsistent.
template
__device__ __inline__ void global_to_share_one_stage_scales(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int cta_offset_k, int global_iter_k, int shared_iter_k, bool mask)
{
    constexpr int LD_AMOUNT = (G >= CTA_K) ? CTA_N : CTA_N * CTA_K / G;
    constexpr int threads_needed = LD_AMOUNT / PACK_SIZE / 1;
    constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
    constexpr int total_global_iters = LD_AMOUNT / PACK_SIZE / threads_used;
    constexpr int threads_per_row = CTA_N / PACK_SIZE;
    constexpr int kSmemCol = CTA_N;
    bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
    // Quantization-group row for this stage's K offset.
    int g_idx = (cta_offset_k + global_iter_k * CTA_K) / G;
    void *dst_ptr = (void *)(dst + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
    if (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
        uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    }
    else
    {
        if (local_mask)
        {
            *(uint4 *)dst_ptr = *src_ptr;
            *(uint4 *)dst_ptr_z = *src_ptr_z;
        }
    }
}
// Move the A fragments for one INTRIN_K step (k_0_1) from shared memory into
// registers via ldmatrix, undoing the XOR swizzle applied on the global->
// shared copy.
// NOTE(review): the template parameter list (type f16_t plus CTA_K and the
// `shared_iters` count used below; order unknown) was stripped during
// extraction — restore from upstream before compiling.
template
__device__ __inline__ void share_to_reg_one_stage_A(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
    constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
        int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8 + warp_offset_k;
        // Same XOR swizzle as the writer in global_to_share_one_stage_A.
        int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
        void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
        uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
        ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
    }
}
// Move the packed B fragments for one INTRIN_K step from shared memory into
// registers (optionally via ldmatrix, gated by the `ldmatrix` flag), then
// dequantize them: unpack uint4 nibbles to fp16/bf16 pairs and apply
// y = q * scale + zero per quantization group, writing the result to dst_fp16.
// NOTE(review): extraction stripped angle-bracket text here — the template
// parameter list, the `packed_as<f16_t, 2>`-style argument on the using-alias,
// and the template arguments of the two reinterpret_casts in the dequantize
// call (expected `uint4 *` source and `f16_t` element views) must be restored
// from the upstream sources before this can compile.
template
__device__ __inline__ void share_to_reg_one_stage_B(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int warp_offset_k, int k_0_1)
{
    using f162_t = typename packed_as::type;
    constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
    // Per-thread source coordinates inside the interleaved B tile.
    int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
    int c0 = ((threadIdx.x / 8) % 2) * 8;
    int r = r0 / 4;
    int c = (r0 % 4) * 16 + c0;
    // Same row-parity XOR swizzle as the writer in global_to_share_one_stage_B.
    int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
    if constexpr (ldmatrix)
    {
#pragma unroll
        for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
        {
            void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled + warp_offset_k);
            uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
            ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
        }
    }
#pragma unroll
    for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
    {
        // Per-output-channel scale and zero point for this K group.
        f16_t scale = src_scales[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        f16_t zero = src_zeros[(warp_offset_k / G) * CTA_N + warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
        f162_t scale2 = f162f162(scale);
        f162_t zero2 = f162f162(zero);
        f162_t loaded[4];
        dequantize_s4_to_f16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded));
#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            // Apply dequantization affine transform: q * scale + zero.
            loaded[i] = __hfma2(loaded[i], scale2, zero2);
        }
        *reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast(loaded);
    }
}
template
__global__ void gemm_w4a16_T1(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int *__restrict__ semaphores, int M, int N, int K)
{
using f162_t = typename packed_as::type;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
float C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES / scales_load_interval * scales_per_load;
constexpr int kSmemSizeZeros = CTA_N * STAGES / scales_load_interval * scales_per_load;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast(mem_shared);
f16_t *B_shared = reinterpret_cast(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
float *C_shared = reinterpret_cast(mem_shared);
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int cta_offset_k = blockIdx_z * (K / SPLITK);
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_n = (warp_mn % (CTA_N / WARP_N)) * WARP_N;
int warp_offset_m = (warp_mn / (CTA_N / WARP_N)) * WARP_M;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K / SPLITK;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_B(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, 0, true);
global_to_share_one_stage_scales(
scales, scales_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
share_to_reg_one_stage_B(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, warp_offset_k, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
#pragma unroll
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval * scales_per_load) * CTA_N;
share_to_reg_one_stage_A(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, warp_offset_k, (iter_k + 1) % SHARED_K_ITERS);
}
}
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, cta_offset_k, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales(
scales, scales_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval * scales_per_load) * CTA_N,
N, cta_offset_m, cta_offset_n, cta_offset_k,
k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1)
{
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] += C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2] = C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
if (slice_id == 0)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id] = C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
if (slice_id == 0)
{
Semaphore semaphore(semaphores + blockIdx_y, threadIdx.x);
if constexpr (SPLITK > 1)
{
semaphore.fetch();
}
if (blockIdx_z != 0)
{
semaphore.wait(blockIdx_z);
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
f162_t *existing_psum_ptr = reinterpret_cast(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2);
*existing_psum_ptr = __hadd2(
*existing_psum_ptr,
cuda_cast(*reinterpret_cast(
C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id)));
}
};
}
}
}
else
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
}
if constexpr (SPLITK > 1)
{
int lock = 0;
if (SPLITK == blockIdx_z + 1)
{
lock = 0;
}
else
{
lock = blockIdx_z + 1;
}
semaphore.release(lock);
}
}
}
// Copies one pipeline stage of the activation tile A (f16, row-major) from global
// memory into shared memory, cooperatively across the CTA. The CTA_M x CTA_K tile
// copy is split over SHARED_K_ITERS calls (indexed by shared_iter_k) so it can be
// interleaved with MMA compute in the main loop. `mask` turns the copy into a
// no-op for pipeline iterations past the end of the K loop.
// NOTE(review): the template parameter list was lost in extraction; the body reads
// CTA_M/CTA_K/CTA_SIZE/SHARED_K_ITERS/STAGES as compile-time constants.
template
__device__ __inline__ void global_to_share_one_stage_A_T2(f16_t *src, f16_t *dst, int global_nrows, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
// Each participating thread moves PACK_SIZE f16 values (one uint4) per iteration;
// the number of participating threads is capped at the CTA size.
constexpr int threads_needed = (CTA_M * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / threads_used;
// Portion of the per-stage copy done by this call (ceil-divide over SHARED_K_ITERS).
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
// Shared-memory row pitch is padded by SMEM_PAD_A (bank-conflict padding).
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
// Bitwise & on bools: both operands are always evaluated (no short-circuit).
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int ld_col = (threadIdx.x % threads_per_row);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
// Precedence: & binds tighter than ^, so this is ld_col ^ (ld_row & 7) — an XOR
// swizzle keyed on the low 3 bits of the row. The matching un-swizzle is applied
// by the reader in share_to_reg_one_stage_A_T2.
int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
if constexpr (STAGES > 1)
{
// Multi-stage pipeline: issue an asynchronous global->shared copy
// (cp_async_cg_A; presumably PTX cp.async.cg — defined elsewhere in the file).
// The row bound check guards the M edge of the problem.
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
}
else
{
// Single-stage fallback: plain 128-bit synchronous copy.
if (local_mask & (ld_row + cta_offset_m < global_nrows))
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
// Copies one pipeline stage of the quantized weight tile B from global memory into
// shared memory. B is stored interleaved: kInterleave output channels are packed
// into each row, so the tile is CTA_N / kInterleave rows by CTA_K columns. Like the
// A loader, the copy is split over SHARED_K_ITERS calls to overlap with compute.
// NOTE(review): unlike the A loader there is no row bound check here — presumably
// the weight tensor's N dimension is always a multiple of CTA_N; verify against
// callers.
template
__device__ __inline__ void global_to_share_one_stage_B_T2(f16_t *src, f16_t *dst, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / SHARED_K_ITERS;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = (CTA_N / kInterleave * CTA_K) / PACK_SIZE / threads_used;
constexpr int partial_global_iters = (total_global_iters + SHARED_K_ITERS - 1) / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (threads_used * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
// Shared-memory row pitch padded by SMEM_PAD_B (bank-conflict padding).
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
// Bitwise & on bools: both operands always evaluated (no short-circuit).
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
int ld_row = global_iter * cta_step_m_or_n + threadIdx.y * warp_step_m_or_n + (threadIdx.x / threads_per_row);
int ld_col = (threadIdx.x % threads_per_row);
// Precedence: & binds tighter than ^, so this is ld_col ^ ((ld_row % 2) & 7),
// which reduces to ld_col ^ (ld_row % 2) — a 1-bit XOR swizzle alternating by
// row parity, undone by share_to_reg_one_stage_B_T2.
int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
if constexpr (STAGES > 1)
{
// Multi-stage pipeline: asynchronous global->shared copy.
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
// Single-stage fallback: synchronous 128-bit copy.
if (local_mask)
*(uint4 *)dst_ptr = *src_ptr;
}
}
}
// Loads the per-group quantization scales and zero-points for the current K-group
// into shared memory: one CTA_N-wide row of scales and one of zeros, copied in a
// single shot (each participating thread moves one PACK_SIZE-element uint4 vector).
// `mask` disables the copy on pipeline iterations that do not cross a group
// boundary. cta_offset_m and shared_iter_k are unused but kept so all stage
// loaders share a uniform signature.
// NOTE(review): the template parameter list was lost in extraction; the body reads
// CTA_N/CTA_K/CTA_SIZE/G/STAGES as compile-time constants.
template
__device__ __inline__ void global_to_share_one_stage_scales_T2(f16_t *src, f16_t *dst, f16_t *src_z, f16_t *dst_z, int global_ncols, int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
// Bitwise & on bools is intentional (both sides always evaluated).
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
// Quantization group index along K for this pipeline iteration.
int g_idx = global_iter_k * CTA_K / G;
void *dst_ptr = (void *)(dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
// `if constexpr`: STAGES is a compile-time constant, so discard the unused branch
// at compile time — consistent with the A/B stage loaders (the original used a
// plain runtime `if`).
if constexpr (STAGES > 1)
{
// Multi-stage pipeline: asynchronous copies for scales and zeros.
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
uint32_t addr_z = cast_smem_ptr_to_uint(dst_ptr_z);
cp_async_cg_A(addr_z, src_ptr_z, local_mask);
}
else
{
// Single-stage fallback: synchronous 128-bit copies.
if (local_mask)
{
*(uint4 *)dst_ptr = *src_ptr;
*(uint4 *)dst_ptr_z = *src_ptr_z;
}
}
}
// Loads A fragments for one 16-wide K slice (selected by k_0_1) from shared memory
// into per-thread registers via ldmatrix (x4, b16), one 16x16 tile per shared_iter.
// NOTE(review): the template parameter list was lost in extraction; shared_iters /
// CTA_K / SMEM_PAD_A / OP_M are compile-time constants here.
template
__device__ __inline__ void share_to_reg_one_stage_A_T2(f16_t *src, f16_t *dst, int warp_offset_m, int warp_offset_n, int k_0_1)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
// Lane mapping from the index math: threadIdx.x % 16 selects the row within the
// 16-row tile, threadIdx.x / 16 selects the 8-column half.
int ld_row = warp_offset_m + shared_iter * OP_M + (threadIdx.x % 16);
int ld_col = k_0_1 * 16 + (threadIdx.x / 16) * 8;
// Undo the store-side XOR swizzle of global_to_share_one_stage_A_T2.
// Precedence: & binds tighter than ^, so this is (ld_col/PACK_SIZE) ^ (ld_row & 7).
int ld_col_swizzled = ((ld_col / PACK_SIZE) ^ (ld_row) & 7) * PACK_SIZE;
void *addr_ptr = (void *)(src + ld_row * kSmemCol + ld_col_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
// Loads packed 4-bit B fragments from shared memory into registers (`dst`, gated by
// the compile-time `ldmatrix` flag) and dequantizes them into f16 fragments
// (`dst_fp16`) using the per-output-channel scale/zero rows already staged in
// shared memory: out = q * scale + zero, two lanes at a time via __hfma2.
// NOTE(review): the template parameter list was lost in extraction.
template
__device__ __inline__ void share_to_reg_one_stage_B_T2(f16_t *src, f16_t *src_scales, f16_t *src_zeros, f16_t *dst, f16_t *dst_fp16, int warp_offset_m, int warp_offset_n, int k_0_1)
{
using f162_t = typename packed_as::type;
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
// Lane -> (row, col) mapping within the interleaved B tile, derived from the
// index math below; c_swizzled undoes the row-parity XOR swizzle applied by
// global_to_share_one_stage_B_T2 (& binds tighter than ^).
int r0 = ((threadIdx.x / 8 / 2) * 8 + threadIdx.x % 8);
int c0 = ((threadIdx.x / 8) % 2) * 8;
int r = r0 / 4;
int c = (r0 % 4) * 16 + c0;
int c_swizzled = ((c / PACK_SIZE) ^ (r % 2) & 7) * PACK_SIZE;
if constexpr (ldmatrix)
{
// Fetch the raw (still-quantized) fragments with ldmatrix; skipped when the
// caller already holds them in `dst` from a previous call.
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
void *addr_ptr = (void *)(src + warp_offset_n / kInterleave * kSmemCol + shared_iter * 16 / kInterleave * kSmemCol + k_0_1 * 16 + r * kSmemCol + c_swizzled);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
// Per-output-channel scale/zero; k_0_1 % 2 selects which 8-column half of the
// 16-wide fragment this call dequantizes.
f16_t scale = src_scales[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
f16_t zero = src_zeros[warp_offset_n + 16 * shared_iter + 8 * (k_0_1 % 2) + threadIdx.x / 4];
// Broadcast the scalars into 2-lane vectors for the paired FMA below.
f162_t scale2 = f162f162(scale);
f162_t zero2 = f162f162(zero);
f162_t loaded[4];
// Unpack eight 4-bit values into four f16x2 pairs.
dequantize_s4_to_f16x2(*reinterpret_cast(dst + (k_0_1 % 2) * 4 + (k_0_1 / 2 * 2) + shared_iter * 8), reinterpret_cast(loaded));
#pragma unroll
for (int i = 0; i < 4; i++)
{
// Affine dequantization: loaded = loaded * scale + zero.
loaded[i] = __hfma2(loaded[i], scale2, zero2);
}
*reinterpret_cast(dst_fp16 + shared_iter * 16 + 8 * (k_0_1 % 2)) = *reinterpret_cast(loaded);
}
}
// W4A16 GEMM kernel, "T2" variant (no split-K, no K-slicing): C[M,N] = A[M,K] x
// dequant(B), where A is f16 and B is 4-bit quantized, interleaved by kInterleave,
// with per-group (G) f16 scales and zero-points. Software-pipelined over STAGES
// shared-memory stages with double-buffered register fragments.
// NOTE(review): the template parameter list was lost in extraction; CTA_M/CTA_N/
// CTA_K/WARP_M/WARP_N/WARP_K/STAGES/G are compile-time constants in this body.
template
__global__ void gemm_w4a16_T2(f16_t *__restrict__ A, f16_t *__restrict__ B, f16_t *__restrict__ scales, f16_t *__restrict__ zeros, f16_t *__restrict__ C, int M, int N, int K)
{
using f162_t = typename packed_as::type;
constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
// --- Block index de-linearization and tile remapping -------------------------
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % (num_blocks_m * num_blocks_n);
// blockIdx_z is computed for symmetry with the split-K kernel but is unused below
// (this variant has no split-K reduction).
int blockIdx_z = blockIdx.x / (num_blocks_m * num_blocks_n);
// Remap (m, n) block coordinates through get_block_idx_mapping — presumably a
// tile-swizzle for L2 locality; see get_log_tile/get_block_idx_mapping elsewhere
// in this file.
const int log_tile = get_log_tile<1>((N + CTA_N - 1) / CTA_N);
int blockIdx_m = blockIdx_y / (num_blocks_n >> log_tile);
int blockIdx_n = blockIdx_y % (num_blocks_n >> log_tile);
const uint2 block_idx_mapping = get_block_idx_mapping(blockIdx_m, blockIdx_n, log_tile);
blockIdx_m = block_idx_mapping.x;
blockIdx_n = block_idx_mapping.y;
// Per-thread float accumulators for the whole CTA tile slice owned by this thread.
float C_warp[CTA_M * CTA_N / CTA_SIZE];
// --- Shared-memory layout: [A stages | B stages | scales | zeros] ------------
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N / kInterleave * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int kSmemSizeScales = CTA_N * STAGES / 2;
constexpr int kSmemSizeZeros = CTA_N * STAGES / 2;
// A new scales/zeros row is needed only once per quantization group (every G/CTA_K
// pipeline iterations).
constexpr int scales_load_interval = G / CTA_K;
extern __shared__ half mem_shared[];
f16_t *A_shared = reinterpret_cast(mem_shared);
f16_t *B_shared = reinterpret_cast(mem_shared + kSmemSizeA);
f16_t *scales_shared = reinterpret_cast(mem_shared + kSmemSizeA + kSmemSizeB);
f16_t *zeros_shared = reinterpret_cast(mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales);
// Double-buffered register fragments (ping-pong on iter_k parity).
f16_t A_shared_warp_[2][WARP_M * INTRIN_K / WARP_SIZE];
f16_t B_shared_warp_[2][WARP_N * 32 / WARP_SIZE];
f16_t B_shared_warp_tmp_[2][WARP_N * 16 / WARP_SIZE];
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_offset_m = (threadIdx.y % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (threadIdx.y / (CTA_M / WARP_M)) * WARP_N;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE; i++)
C_warp[i] = 0.0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
// k_0_0_ld: next stage to load; k_0_0: stage being computed.
int k_0_0_ld = 0;
int k_0_0 = 0;
// --- Prologue: prefetch the first STAGES-1 stages (or 1 when unpipelined) ----
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A_T2(A, A_shared + k_0_0_ld * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_B_T2(B, B_shared + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
// Scales/zeros are only (re)loaded at quantization-group boundaries.
global_to_share_one_stage_scales_T2(
scales, scales_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
zeros, zeros_shared + (k_0_0_ld / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
// Preload the first register fragments before entering the main loop.
share_to_reg_one_stage_A_T2(A_shared, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
share_to_reg_one_stage_B_T2(B_shared, scales_shared, zeros_shared, B_shared_warp_tmp_[0], B_shared_warp_[0], warp_offset_m, warp_offset_n, 0);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
// --- Main loop: per stage, SHARED_K_ITERS register-level K steps -------------
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
f16_t *A_shared_this_compute_stage;
f16_t *B_shared_this_compute_stage;
f16_t *scales_shared_this_compute_stage;
f16_t *zeros_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage;
scales_shared_this_compute_stage = scales_shared + (compute_stage / scales_load_interval) * CTA_N;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage / scales_load_interval) * CTA_N;
// Prefetch the next K step's A fragment into the other register buffer.
share_to_reg_one_stage_A_T2(A_shared_this_compute_stage, A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
// Prefetch the next B fragment. The four branches pass different tmp/out
// buffers chosen by compute-stage parity and (iter_k+1) % kInterleave; the
// ldmatrix template flag (stripped by extraction) presumably differs between
// the two outer branches.
if ((iter_k + 1) % kInterleave == 0)
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
else
{
if (compute_stage % 2 == 1)
{
share_to_reg_one_stage_B_T2(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[1], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
else
{
share_to_reg_one_stage_B_T2(
B_shared_this_compute_stage, scales_shared_this_compute_stage, zeros_shared_this_compute_stage,
B_shared_warp_tmp_[0], B_shared_warp_[((iter_k + 1) / 2) % 2],
warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS);
}
}
__syncthreads();
// Compute on the fragments loaded for the current iter_k (ping-pong buffers).
f16_t *A_shared_warp = A_shared_warp_[iter_k % 2];
f16_t *B_shared_warp = B_shared_warp_[(iter_k / 2) % 2];
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
// Two m16n8k16 tensor-core MMAs cover one 16x16 output tile.
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4);
mma_m16n8k16(C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4, A_shared_warp + i_0_3 * 8, B_shared_warp + j_0_4 * 16 + (iter_k % 2) * 4 + 8);
}
}
// Overlap compute with the global->shared load of stage k_0_0_ld, spread over
// the remaining K steps.
if (iter_k < WARP_K / INTRIN_K - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters);
}
// On the penultimate K step, finish the stage load (including the final copy
// chunk and scales/zeros), commit the async pipeline, and advance stages.
if (iter_k == WARP_K / INTRIN_K - 2)
{
if constexpr (STAGES == 1 && WARP_K / INTRIN_K > 2)
{
__syncthreads();
}
global_to_share_one_stage_A_T2(A, A_shared + ld_stage * kSmemSizeAPerStage, M, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_B_T2(B, B_shared + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_scales_T2(
scales, scales_shared + (ld_stage / scales_load_interval) * CTA_N,
zeros, zeros_shared + (ld_stage / scales_load_interval) * CTA_N,
N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k, k_0_0_ld < gemm_iters && k_0_0_ld % scales_load_interval == 0);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
// --- Epilogue: cast float accumulators to f16 pairs and store to C -----------
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
// Only the M edge needs a bound check (N is handled by the launch tiling).
int write_row = cta_offset_m + warp_offset_m + ax0_0_1 * OP_M + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4));
if (write_row < M)
{
*reinterpret_cast(
C + write_row * N +
cta_offset_n + warp_offset_n + ax1_0_1 * 16 +
(local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2) =
cuda_cast(*reinterpret_cast(C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 +
ax1_0_1 * 8 + local_id));
}
};
}
}
}
// Host entry point for the AWQ W4A16 GEMM: out = in_feats @ dequant(kernel)^T-style
// matmul with 4-bit interleaved weights, per-group scales and zeros.
// Dispatches on the activation dtype (fp16 / bf16) and picks a tile configuration
// by batch size (num_out_feats); small batches go through the split-K kernel via
// the KERNEL_LAUNCH_CODE macro (defined elsewhere), large batches launch the
// non-split-K gemm_w4a16_T2 kernel directly.
// NOTE(review): template/cast argument lists and the <<<...>>> launch config were
// stripped by extraction in several places below; the surrounding logic is intact.
torch::Tensor awq_gemm_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scales,
torch::Tensor _zeros)
{
// Output keeps the input's leading dims; the last dim becomes the number of
// output channels (kernel rows are interleaved by kInterleave).
std::vector output_shape = _in_feats.sizes().vec();
output_shape.back() = _kernel.size(0) * kInterleave;
int num_in_feats = _in_feats.numel() / _in_feats.size(-1);
int num_in_channels = _in_feats.size(-1);
auto options =
torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
auto options_int =
torch::TensorOptions().dtype(torch::kInt32).device(_in_feats.device());
at::Tensor _out_feats = torch::empty(output_shape, options);
int num_out_feats = _out_feats.numel() / _out_feats.size(-1);
int num_out_channels = _out_feats.size(-1);
if (_in_feats.scalar_type() == at::ScalarType::Half)
{
// fp16 path.
using f16_t = half;
auto in_feats = reinterpret_cast(_in_feats.data_ptr());
auto kernel = reinterpret_cast(_kernel.data_ptr());
auto scales = reinterpret_cast(_scales.data_ptr());
auto zeros = reinterpret_cast(_zeros.data_ptr());
auto out_feats = reinterpret_cast(_out_feats.data_ptr());
// Tile-configuration ladder: smaller batches use smaller CTA_M (and split-K /
// deeper pipelining for the tiniest batch) to keep SMs occupied.
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
// Large batch: launch the non-split-K T2 kernel directly with dynamic
// shared memory sized for A + B + scales/zeros stages.
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
// Bail out (returns the uninitialized output tensor) rather than launch a
// kernel that cannot get its shared memory.
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2;
// Opt in to >48KB dynamic shared memory before launching.
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
}
else if (_in_feats.scalar_type() == at::ScalarType::BFloat16)
{
// bf16 path: identical configuration ladder with f16_t = __nv_bfloat16.
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast(_in_feats.data_ptr());
auto kernel = reinterpret_cast(_kernel.data_ptr());
auto scales = reinterpret_cast(_scales.data_ptr());
auto zeros = reinterpret_cast(_zeros.data_ptr());
auto out_feats = reinterpret_cast(_out_feats.data_ptr());
if (num_out_feats <= 32)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 2;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 64)
{
constexpr int G = 128;
constexpr int CTA_M = 16;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 16;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 128)
{
constexpr int G = 128;
constexpr int CTA_M = 32;
constexpr int CTA_N = 128;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats <= 192)
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int SPLITK = 1;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int G = 128;
constexpr int CTA_M = 64;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N);
constexpr int kSmemByteSize = (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / kInterleave + CTA_N) * STAGES * sizeof(f16_t);
if (kSmemByteSize >= 99 * 1024)
{
printf("This kernel requires %d Bytes of shared memory, which exceeds device limit.\n", kSmemByteSize);
return _out_feats;
}
int j_factors1 = num_out_channels / CTA_N / 1;
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
}
else
{
AT_ERROR("Unsupported input type");
}
return _out_feats;
}
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/gemm/gemm_cuda.h
================================================
#include
torch::Tensor awq_gemm_forward_cuda(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scales,
torch::Tensor _zeros);
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/gemm/semaphore.h
================================================
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
// namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
public:
// Global-memory flag shared by cooperating CTAs.
int *lock;
// True only for the designated waiter thread (thread_id <= 0); only this thread
// touches global memory, the rest learn the state via __syncthreads_and.
bool wait_thread;
// Last value read from *lock (-1 until the first fetch()).
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
// sm_70+: acquire load so subsequent reads see writes ordered before the
// matching release() by the producer CTA; older archs fall back to ld.cg.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
// Spin: __syncthreads_and broadcasts the waiter thread's comparison to the
// whole CTA, so all threads leave the loop together.
while (__syncthreads_and(state != status))
{
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
// Barrier first so every thread's prior global writes are issued before the
// flag is published.
__syncthreads();
if (wait_thread)
{
// sm_70+: release store pairs with the acquire load in fetch().
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// } // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/gemv/gemv_cuda.cu
================================================
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include
#include
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#include "../../utils.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
// Reduce sum within the warp using the tree reduction algorithm, then stage
// each warp's per-row partial sums into shared memory for the final
// cross-warp reduction done by the caller.
// NOTE(review): the template parameter list on the next line was stripped by
// markup extraction (it reads a bare `template`); presumably it was
// `template <typename fp_t, int WarpSize, int Num>` -- confirm against the
// upstream TinyChat/TRT-LLM source before compiling.
template
__device__ __forceinline__ static void warp_reduce(fp_t *psum, float (*out_smem)[Num * 4])
{
    // kInterleave = 4
    float fpsum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i)
    {
        // NOTE(review): cast target stripped by extraction (bare static_cast).
        fpsum[i] = static_cast(psum[i]);
    }
#pragma unroll
    for (int i = 0; i < Num; ++i)
    {
        // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
        // XOR offsets 16, 8, 1 (2 and 4 intentionally skipped) sum the lanes
        // that contribute to the same interleaved output row.
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
    }
    __syncthreads();
    int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
    // After the shuffles, lanes 0/2/4/6 hold the sums for the four
    // interleaved rows; lane/2 selects the row slot.
    if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
    {
#pragma unroll
        for (int i = 0; i < Num; ++i)
        {
            out_smem[warp][i * 4 + lane / 2] = fpsum[i];
        }
    }
    __syncthreads();
};
/// Ceiling division: the smallest integer q with q * divisor >= c for
/// non-negative inputs. Used to size tiles/grids that must fully cover
/// `c` elements.
__device__ __forceinline__ int make_divisible(int c, int divisor)
{
    const int rounded_up = c + divisor - 1;
    return rounded_up / divisor;
}
// AWQ W4A16 batched GEMV kernel: each block produces NPerBlock * kInterleave
// output channels for `Batch` input rows, dequantizing INT4 weights on the fly.
// NOTE(review): markup extraction stripped every `<...>` span in this kernel:
// the template parameter list (bare `template`, presumably
// `template <typename f16_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>`),
// the `packed_as<...>` arguments, and the type arguments of the
// static_cast / reinterpret_cast / cuda_cast calls. Restore from the upstream
// TinyChat/TRT-LLM source before compiling.
template
__global__ void gemv_kernel(
    const f16_t *inputs, const uint32_t *weight, const f16_t *scales, const f16_t *zeros, f16_t *outputs,
    const int IC, const int OC)
{
    using f162_t = typename packed_as::type;
    using accum_t = float;
    using accum2_t = typename packed_as::type;
    // Tiling constants: each thread loads 128 bits (32 int4 values) at a time.
    const int kStride = 64;
    const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
    const int kThreadsNumPerTile = kStride / kElemsPerThread;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided = 4;
    constexpr int Num = NPerBlock * Batch;
    constexpr int kInterleave = 4;
    // Per-thread staging buffers (128-bit aligned for float4 loads).
    alignas(16) f16_t local_inputs[kElemsPerThread];
    alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
    alignas(16) f16_t half_weight_buffer[kElemsPerThread];
    alignas(16) f16_t dequantized_weight[kElemsPerThread * NPerBlock];
    alignas(16) f16_t local_scale[NPerBlock];
    alignas(16) f16_t local_scaled_zeros[NPerBlock];
    // Partial sums: one fp32 accumulator per (batch row, output channel) pair.
    accum_t psum[Num];
    for (int i = 0; i < Num; ++i)
        psum[i] = static_cast(0.f);
    extern __shared__ uint8_t shmem[];
    float(*out_smem)[Num * kInterleave] = reinterpret_cast(shmem);
    // Offsets of this block / thread into the interleaved weight layout.
    const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
    const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
    const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride + (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
    const int group_offset = act_k_offset / GroupSize;
    // TODO: use make_divisible
    const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
    const f16_t *scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
    const f16_t *zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
    const f16_t *inputs_ptr = inputs + act_k_offset;
    const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
    const int scale_forward_step = act_forward_step / GroupSize * OC;
    // Main loop iteration, each block completes the outputs for several OCs
    for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
    {
        // Load qweight, scales and scaled_zeros
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx)
        {
            // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
            *((float4 *)(local_qweights)) =
                *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
            local_scale[idx] = *(scale_ptr + idx * kInterleave);
            local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
            // Map int4 qweight to fp format
#pragma unroll
            for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
            {
                // Converts 32 bits (8 x int4) to 8 fp16
                dequantize_s4_to_f16x2(*reinterpret_cast(local_qweights + i), reinterpret_cast(half_weight_buffer + i * PACK_FACTOR));
            }
            // Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
            for (int i = 0; i < kShuffleContinous; ++i)
            {
#pragma unroll
                for (int j = 0; j < kShuffleStrided; ++j)
                {
                    f162_t w =
                        *reinterpret_cast(
                            half_weight_buffer + (i + j * kShuffleContinous) * kShuffleBasicTile);
                    // w = w * scale + scaled_zero (zeros are stored pre-scaled and negated).
                    w = __hfma2(w, f162f162(local_scale[idx]), f162f162(local_scaled_zeros[idx]));
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
                }
            }
        }
#pragma unroll
        for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
        {
            const f16_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
            for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
            {
                // load activation, 8 halves (128 bits) / step.
                *((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
            }
            // Perform the MACs
#pragma unroll
            for (int x = 0; x < NPerBlock / 2; ++x)
            {
#pragma unroll
                for (int y = 0; y < kElemsPerThread; ++y)
                {
                    // Multiply a pair of dequantized weights by the broadcast
                    // activation and accumulate into the fp32 partial sums.
                    accum2_t prod = cuda_cast(__hmul2(
                        *reinterpret_cast(dequantized_weight + y * NPerBlock + x * 2),
                        f162f162(local_inputs[y])));
                    *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2) = prod + *reinterpret_cast(psum + batch_idx * NPerBlock + x * 2);
                }
            }
        }
        inputs_ptr += act_forward_step;
        scale_ptr += scale_forward_step;
        zeros_ptr += scale_forward_step;
    }
    // Warp-level reduction into shared memory, then cross-warp reduction and
    // the final write of this block's output channels.
    warp_reduce(psum, out_smem);
    // Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
    {
        int batch_idx = i / (NPerBlock * kInterleave);
        int oc_idx = i % (NPerBlock * kInterleave);
        float acc = 0.f;
        for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
        {
            acc += out_smem[j][i];
        }
        outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast(acc);
    }
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: packed int tensor of quantized weights;
_scaling_factors: per-group scale tensor;
_zeros: per-group (scaled) zero-point tensor;
m: batch size B;
n: number of output channels OC;
k: number of input channels IC;
group_size: quantization group size G (only 128 is supported);
Returns:
out_feats: tensor of shape [B, OC];
*/
// PyTorch entry point for the AWQ W4A16 GEMV: allocates the output tensor,
// dispatches on activation dtype (fp16/bf16), and launches a batch-size
// specialized instantiation of gemv_kernel.
// NOTE(review): markup extraction stripped several `<...>` spans below:
// `std::vector` lost its element type (presumably int64_t), `to_cpp_t::type`
// its at::ScalarType argument, the `reinterpret_cast(...)` calls their target
// types, and the kernel launches read `gemv_kernel<<>>` -- both the template
// arguments and the `<<<grid, block, smem>>>` launch configuration were
// removed. Restore from the upstream source before building.
torch::Tensor awq_gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int m,
    int n,
    int k,
    int group_size)
{
    // Output keeps the input's leading shape; the last dim becomes OC (= n).
    std::vector output_shape = _in_feats.sizes().vec();
    output_shape.back() = n;
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    at::Tensor _out_feats = torch::empty(output_shape, options);
    // Tiling constants matching gemv_kernel: each block covers
    // N_PER_BLOCK * K_INTERLEAVE output channels.
    static constexpr int N_PER_BLOCK = 2;
    static constexpr int K_INTERLEAVE = 4;
    static constexpr int BLOCK_SIZE = 256;
    dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
    dim3 num_threads(BLOCK_SIZE);
    AT_DISPATCH_REDUCED_FLOATING_TYPES(
        _in_feats.scalar_type(),
        "awq_gemv_forward_cuda",
        [&]
        {
            using f16_t = typename to_cpp_t::type;
            auto in_feats = reinterpret_cast(_in_feats.data_ptr());
            auto kernel = reinterpret_cast(_kernel.data_ptr());
            auto zeros = reinterpret_cast(_zeros.data_ptr());
            auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr());
            auto out_feats = reinterpret_cast(_out_feats.data_ptr());
            if (group_size == 128)
            {
                // One compile-time specialization per batch size (Batch is a
                // template parameter of gemv_kernel).
                switch (m)
                {
                case 1:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 2:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 3:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 4:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 5:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 6:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                case 7:
                    gemv_kernel<<>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                    break;
                default:
                    throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
                }
            }
            else
            {
                throw std::runtime_error("Unsupported group size for gemv kernel.\n");
            }
        });
    return _out_feats;
}
================================================
FILE: deepcompressor/backend/tinychat/csrc/quantization/gemv/gemv_cuda.h
================================================
#pragma once
#include
// AWQ W4A16 GEMV entry point (implemented in gemv_cuda.cu).
// m: batch size, n: output channels, k: input channels,
// group_size: quantization group size (only 128 is supported).
torch::Tensor awq_gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int m,
    int n,
    int k,
    int group_size);
================================================
FILE: deepcompressor/backend/tinychat/csrc/utils.cuh
================================================
// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once
#include
#include
#include
#include
#include
#include
#define ENABLE_BF16 1
// ---------------------------------------------------------------------------
// Compile-time type traits used by the quantization kernels.
// NOTE(review): markup extraction stripped the `<...>` argument lists of most
// of the template headers and specializations below (several read as
// identical bare `template` / `struct X` lines). The intended specialization
// arguments can be recovered from the trait bodies (e.g. the two `to_cpp_t`
// specializations are presumably `<at::ScalarType::Half>` and
// `<at::ScalarType::BFloat16>`); confirm against upstream before compiling.
// ---------------------------------------------------------------------------
// Maps a torch at::ScalarType enum value to the corresponding CUDA scalar type.
template
struct to_cpp_t;
template <>
struct to_cpp_t
{
    using type = half;
};
template <>
struct to_cpp_t
{
    using type = __nv_bfloat16;
};
// Number of scalar lanes in a (possibly packed) CUDA vector type.
template
struct num_elems;
template <>
struct num_elems
{
    static constexpr int value = 1;
};
template <>
struct num_elems
{
    static constexpr int value = 2;
};
template <>
struct num_elems
{
    static constexpr int value = 4;
};
template <>
struct num_elems
{
    static constexpr int value = 1;
};
template <>
struct num_elems
{
    static constexpr int value = 2;
};
#ifdef ENABLE_BF16
template <>
struct num_elems<__nv_bfloat16>
{
    static constexpr int value = 1;
};
template <>
struct num_elems<__nv_bfloat162>
{
    static constexpr int value = 2;
};
#endif
// Maps a scalar type to its N-lane packed vector type (and back to scalar
// for N == 1).
template
struct packed_as;
template
struct packed_as
{
    using type = T;
};
template <>
struct packed_as
{
    using type = half2;
};
template <>
struct packed_as
{
    using type = float2;
};
template <>
struct packed_as
{
    using type = int16_t;
};
template <>
struct packed_as
{
    using type = int2;
};
template <>
struct packed_as
{
    using type = half;
};
template <>
struct packed_as
{
    using type = float;
};
#ifdef ENABLE_BF16
template <>
struct packed_as<__nv_bfloat16, 2>
{
    using type = __nv_bfloat162;
};
template <>
struct packed_as<__nv_bfloat162, 1>
{
    using type = __nv_bfloat16;
};
#endif
#ifdef ENABLE_FP8
template <>
struct packed_as<__nv_fp8_e4m3, 2>
{
    using type = __nv_fp8x2_e4m3;
};
template <>
struct packed_as<__nv_fp8x2_e4m3, 1>
{
    using type = __nv_fp8_e4m3;
};
template <>
struct packed_as<__nv_fp8_e5m2, 2>
{
    using type = __nv_fp8x2_e5m2;
};
template <>
struct packed_as<__nv_fp8x2_e5m2, 1>
{
    using type = __nv_fp8_e5m2;
};
#endif
// Broadcast one 16-bit float (half or bf16) into both lanes of its packed
// two-lane type.
// NOTE(review): the primary-template parameter lists here were stripped by
// extraction (bare `template` lines); the visible bf16 specializations show
// the intended shape.
template
__device__ __forceinline__
    packed_as::type
    f162f162(f16_t x);
template <>
__device__ __forceinline__
    packed_as::type
    f162f162(half x)
{
    return __half2half2(x);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
    packed_as<__nv_bfloat16, 2>::type
    f162f162<__nv_bfloat16>(__nv_bfloat16 x)
{
    return __bfloat162bfloat162(x);
}
#endif
// Widen a packed pair of 16-bit floats to float2.
template
__device__ __forceinline__
    float2
    f1622float2(T val);
template <>
__device__ __forceinline__
    float2
    f1622float2(half2 val)
{
    return __half22float2(val);
}
#ifdef ENABLE_BF16
template <>
__device__ __forceinline__
    float2
    f1622float2<__nv_bfloat162>(__nv_bfloat162 val)
{
    return __bfloat1622float2(val);
}
#endif
// Element-wise float2 arithmetic helpers: vector-vector and vector-scalar
// forms of *, + and - (the scalar is applied to both lanes).
inline __device__ float2 operator*(float2 a, float2 b) { return float2{a.x * b.x, a.y * b.y}; }
inline __device__ float2 operator+(float2 a, float2 b) { return float2{a.x + b.x, a.y + b.y}; }
inline __device__ float2 operator-(float2 a, float2 b) { return float2{a.x - b.x, a.y - b.y}; }
inline __device__ float2 operator*(float2 a, float b) { return float2{a.x * b, a.y * b}; }
inline __device__ float2 operator+(float2 a, float b) { return float2{a.x + b, a.y + b}; }
inline __device__ float2 operator-(float2 a, float b) { return float2{a.x - b, a.y - b}; }
// Convert a float to int8 with round-to-nearest-even and saturation, via the
// PTX cvt instruction.
static inline __device__ int8_t float_to_int8_rn(float x)
{
    uint32_t dst;
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
    // NOTE(review): the reinterpret_cast target was stripped by extraction
    // (bare `reinterpret_cast(dst)`); presumably it reinterprets the low byte
    // of `dst` as int8_t -- confirm against upstream.
    return reinterpret_cast(dst);
}
// Read-only cached load (__ldg) wrapper; specialized below for bf16 types.
// NOTE(review): template parameter list stripped by extraction.
template
inline __device__ T ldg(const T *val)
{
    return __ldg(val);
}
#if ENABLE_BF16
#define float22bf162 __float22bfloat162_rn
// Convert a packed pair of bf16 values to two saturated int8 values packed in
// an int16. Pre-sm_80 devices lack bf16 min/max intrinsics, so that path
// clamps in fp32 instead.
// NOTE(review): the static_cast chains below lost their type arguments to
// extraction (bare `static_cast(static_cast(...))`); confirm against upstream.
inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    // Anonymous union reinterprets the two int8 bytes as one int16.
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int8[0] = static_cast(static_cast(f_val.x));
    int8[1] = static_cast(static_cast(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int8[0] = static_cast(static_cast(val.x));
    int8[1] = static_cast(static_cast(val.y));
    return int16;
#endif
}
#endif
#if ENABLE_BF16
// bf16 specializations of ldg: __ldg has no bf16 overload before sm_80, so
// fall back to a plain load there.
template <>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
    return __ldg(val);
#endif
}
template <>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return val[0];
#else
    return __ldg(val);
#endif
}
#endif // ENABLE_BF16
// ---------------------------------------------------------------------------
// cuda_cast<To, From>: generic value conversion between CUDA scalar and
// packed vector types. The primary template is the identity; specializations
// implement widening, narrowing and int8 saturation conversions.
// NOTE(review): extraction stripped the primary template's parameter list and
// the explicit template arguments of most specializations below (bare
// `cuda_cast(...)` calls); the surviving bf16 specializations further down
// show the intended `cuda_cast<To, From>` shape -- confirm against upstream.
// ---------------------------------------------------------------------------
template
__device__ inline T_OUT cuda_cast(T_IN val)
{
    return val;
}
template <>
__device__ inline float2 cuda_cast(int2 val)
{
    return make_float2(val.x, val.y);
}
template <>
__device__ inline float2 cuda_cast(float val)
{
    // Broadcast the scalar into both lanes.
    return make_float2(val, val);
}
template <>
__device__ inline float2 cuda_cast(half2 val)
{
    return __half22float2(val);
}
template <>
__device__ inline half2 cuda_cast(float2 val)
{
    return __float22half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast(float val)
{
    return __float2half2_rn(val);
}
template <>
__device__ inline half2 cuda_cast(half val)
{
    return __half2half2(val);
}
// half -> int8 with round-to-nearest-even and saturation via PTX cvt; the
// unions reinterpret the half's bit pattern and the resulting byte.
template <>
__device__ inline int8_t cuda_cast(half val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    union
    {
        half fp16;
        int16_t int16_in;
    };
    fp16 = val;
    asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
    return int8[0];
}
// half2 -> two packed saturated int8 values in an int16.
template <>
__device__ inline int16_t cuda_cast(half2 val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast(val.x);
    int8[1] = cuda_cast(val.y);
    return int16;
}
// float -> int8 with round-to-nearest-even and saturation.
template <>
__device__ inline int8_t cuda_cast(float val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
    return int8[0];
}
// float2 -> two packed saturated int8 values in an int16.
template <>
__device__ inline int16_t cuda_cast(float2 val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int8[0] = cuda_cast(val.x);
    int8[1] = cuda_cast(val.y);
    return int16;
}
// Packed int8 pair (as int16) -> half2 / float2.
template <>
__device__ inline half2 cuda_cast(int16_t val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_half2(int8[0], int8[1]);
}
template <>
__device__ inline float2 cuda_cast(int16_t val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int16 = val;
    return make_float2(int8[0], int8[1]);
}
#ifdef ENABLE_BF16
// bf16 conversions. Some specializations below kept their explicit template
// arguments through extraction; the first few lost them (and the bare
// `static_cast(val)` calls lost their type arguments) -- confirm upstream.
template <>
__device__ inline __nv_bfloat16 cuda_cast(int32_t val)
{
    return static_cast(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast(int8_t val)
{
    return static_cast(val);
}
template <>
__device__ inline int8_t cuda_cast(__nv_bfloat16 val)
{
    return static_cast(val);
}
template <>
__device__ inline float cuda_cast(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}
template <>
__device__ inline float2 cuda_cast(__nv_bfloat162 val)
{
    return __bfloat1622float2(val);
}
template <>
__device__ inline half cuda_cast(__nv_bfloat16 val)
{
    // bf16 -> half goes through fp32 (no direct intrinsic).
    return __float2half(__bfloat162float(val));
}
template <>
__device__ inline int16_t cuda_cast(__nv_bfloat162 val)
{
    return bf1622int16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val)
{
    return __float2bfloat16(val);
}
template <>
__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val)
{
    return __float2bfloat16(__half2float(val));
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val)
{
    return __bfloat162bfloat162(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val)
{
    return __float2bfloat162_rn(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val)
{
    return float22bf162(val);
}
// Packed int8 pair (as int16) -> bf16 pair.
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val)
{
    union
    {
        int8_t int8[2];
        int16_t int16;
    };
    int16 = val;
    __nv_bfloat162 res;
    res.x = cuda_cast<__nv_bfloat16>(int8[0]);
    res.y = cuda_cast<__nv_bfloat16>(int8[1]);
    return res;
}
template <>
__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val)
{
    return float22bf162(__half22float2(val));
}
#endif // ENABLE_BF16
// Reduce a value to a scalar sum: identity for scalars, lane-sum for float2.
// NOTE(review): the template parameter lists and the explicit arguments of
// the `cuda_cast(...)` calls were stripped by extraction -- confirm upstream.
template
__device__ inline To cuda_sum(Ti val)
{
    return cuda_cast(val);
};
template
__device__ inline To cuda_sum(float2 val)
{
    return cuda_cast(val.x + val.y);
};
// Unary maximum: compute the max of a vector type
template
__device__ inline To cuda_max(Ti val)
{
    return cuda_cast(val);
};
template <>
__device__ inline float cuda_max(float2 val)
{
    return fmaxf(val.x, val.y);
}
template <>
__device__ inline half cuda_max(half2 val)
{
    return __hmax(val.x, val.y);
}
#ifdef ENABLE_BF16
// Unary max over a bf16 pair.
template <>
__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
    return __hmax(val.x, val.y);
#else
    // Fix: the original had no pre-sm_80 branch, so the function fell off the
    // end without returning a value (undefined behavior when compiled for
    // older architectures). bf16 __hmax requires sm_80+, so compare in fp32
    // and convert back instead.
    return __float2bfloat16(fmaxf(__bfloat162float(val.x), __bfloat162float(val.y)));
#endif
}
#endif
// Binary maximum: compute the max of two scalar types
// NOTE(review): the template parameter lists on the generic definitions below
// were stripped by extraction (bare `template` lines) -- confirm upstream.
template
__device__ inline T cuda_max(T val1, T val2)
{
    return (val1 > val2) ? val1 : val2;
}
// Element-wise absolute value. The unconstrained primary template
// deliberately traps (assert(false)): only the specializations are usable.
template
__device__ inline T cuda_abs(T val)
{
    assert(false);
    return {};
}
template <>
__device__ inline float cuda_abs(float val)
{
    return fabs(val);
}
template <>
__device__ inline float2 cuda_abs(float2 val)
{
    return make_float2(fabs(val.x), fabs(val.y));
}
template <>
__device__ inline half cuda_abs(half val)
{
    return __habs(val);
}
template <>
__device__ inline half2 cuda_abs(half2 val)
{
    return __habs2(val);
}
#ifdef ENABLE_BF16
// bf16 __habs/__habs2 intrinsics require sm_80+ (or host-side compilation).
#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)
template <>
__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val)
{
    return __habs(val);
}
template <>
__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val)
{
    return __habs2(val);
}
#endif
#endif
================================================
FILE: deepcompressor/backend/tinychat/linear.py
================================================
# -*- coding: utf-8 -*-
"""TinyChat Quantized Linear Module"""
import warnings
import torch
import torch.nn as nn
from .csrc.load import _C
from .utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
__all__ = ["W4Linear"]
warnings.warn(
"Module `tinychat.linear` will be moved to `Nunchaku` and deprecated in the future release.",
DeprecationWarning,
stacklevel=2,
)
class W4Linear(nn.Module):
    """TinyChat 4-bit weight-only quantized linear layer (INT4 weights, 16-bit activations).

    Weights are stored packed in ``qweight`` together with per-group ``scales``
    and pre-scaled (negated) zero points in ``scaled_zeros``; the forward pass
    dispatches to the TinyChat CUDA kernels in ``_C``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        group_size: int = 128,
        dtype: torch.dtype = torch.float16,
        device: str | torch.device = "cuda",
    ):
        super().__init__()
        assert dtype in (torch.float16, torch.bfloat16), f"Unsupported dtype: {dtype}"
        self.in_features = in_features
        self.out_features = out_features
        # group_size == -1 means a single group spanning all input channels.
        self.group_size = group_size if group_size != -1 else in_features
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.weight_bits) == 0
        # Number of groups rounded up to the kernel's packing granularity.
        self.ceil_num_groups = ceil_num_groups(
            in_features=self.in_features,
            group_size=self.group_size,
            weight_bits=self.weight_bits,
        )
        assert out_features % (self.interleave) == 0
        # Packed weights: `interleave` output channels are fused per row and
        # each int16 holds 16 // weight_bits quantized values.
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    self.out_features // self.interleave,
                    self.in_features // (16 // self.weight_bits) * self.interleave,
                ),
                dtype=torch.int16,
                device=device,
            ),
        )
        # Per-group scales, laid out as [num_groups, out_features].
        self.register_buffer(
            "scales",
            torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
        )
        # Per-group zero points, pre-multiplied by the scales (and negated by
        # the converter in utils.py), same layout as `scales`.
        self.register_buffer(
            "scaled_zeros",
            torch.zeros((self.ceil_num_groups, self.out_features), dtype=dtype, device=device),
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
        else:
            self.bias = None

    @property
    def weight_bits(self) -> int:
        # Quantized weight bit-width.
        return 4

    @property
    def interleave(self) -> int:
        # Number of output channels interleaved per packed weight row.
        return 4

    @torch.no_grad()
    def forward(self, x):
        # Dispatch on the number of tokens: fewer than 8 rows use the GEMV
        # kernel, larger batches use the GEMM kernel.
        if x.numel() / x.shape[-1] < 8:
            out = _C.awq_gemv_forward_cuda(
                x,
                self.qweight,
                self.scales,
                self.scaled_zeros,
                x.numel() // x.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            out = _C.awq_gemm_forward_cuda(x, self.qweight, self.scales, self.scaled_zeros)
        out = out + self.bias if self.bias is not None else out
        return out

    @staticmethod
    def from_linear(
        linear: nn.Linear,
        group_size: int,
        init_only: bool = False,
        weight: torch.Tensor | None = None,
        scale: torch.Tensor | None = None,
        zero: torch.Tensor | None = None,
        zero_pre_scaled: bool = False,
    ) -> "W4Linear":
        """Convert a linear layer to a TinyChat 4-bit weight-only quantized linear layer.

        Args:
            linear (`nn.Linear`):
                linear layer to be converted.
            group_size (`int`):
                quantization group size.
            init_only (`bool`, *optional*, defaults to `False`):
                whether to only initialize the quantized linear layer.
            weight (`torch.Tensor`, *optional*, defaults to `None`):
                weight tensor for the quantized linear layer.
            scale (`torch.Tensor`, *optional*, defaults to `None`):
                scale tensor for the quantized linear layer.
            zero (`torch.Tensor`, *optional*, defaults to `None`):
                zero point tensor for the quantized linear layer.
            zero_pre_scaled (`bool`, *optional*, defaults to `False`):
                whether zero point tensor is pre-scaled.

        Returns:
            `W4Linear`:
                quantized linear layer.
        """
        assert isinstance(linear, nn.Linear)
        weight = linear.weight.data if weight is None else weight.data
        dtype, device = weight.dtype, weight.device
        oc, ic = linear.out_features, linear.in_features
        _linear = W4Linear(
            in_features=ic,
            out_features=oc,
            bias=linear.bias is not None,
            group_size=group_size,
            dtype=dtype,
            device=device,
        )
        if init_only:
            return _linear
        if linear.bias is not None:
            _linear.bias.data.copy_(linear.bias.data)
        if scale is None:
            # No precomputed quantization parameters: derive per-group 4-bit
            # asymmetric (min-max) scale/zero and fake-quantize the weight.
            assert zero is None, "scale and zero point tensors should be provided together."
            group_size = ic if group_size <= 0 else group_size
            assert group_size <= ic, "group size should be less than or equal to input channel size."
            assert ic % group_size == 0, "input channel size should be divisible by group size."
            ng, gs = ic // group_size, group_size
            weight = weight.to(dtype=torch.float32).view(oc, 1, ng, gs)
            vmin, vmax = weight.amin(dim=-1, keepdim=True), weight.amax(dim=-1, keepdim=True)
            scale = (vmax - vmin).div_(15)
            scale[scale == 0] = 1.0
            if zero_pre_scaled:
                # Integer zero point in [0, 15] (applied before scaling).
                zero = vmin.neg_().div_(scale).round_().clamp_(0, 15)
                weight = weight.div_(scale).add_(zero).round_().clamp_(0, 15).sub_(zero).mul_(scale)
            else:
                # Floating-point zero point (already in the scaled domain).
                zero = vmin.neg_().clamp_min(0)
                weight = weight.add_(zero).div_(scale).round_().clamp_(0, 15).mul_(scale).sub_(zero)
            weight = weight.to(dtype=dtype).view(oc, ic)
            scale = scale.to(dtype=dtype)
            zero = zero.to(dtype=dtype)
        # Repack into the TinyChat kernel layout (packed qweight, transposed
        # scales, negated scaled zeros).
        weight, scale, zero = convert_to_tinychat_w4x16y16_linear_weight(
            weight=weight,
            scale=scale,
            zero=zero,
            zero_pre_scaled=zero_pre_scaled,
        )
        _linear.qweight.data.copy_(weight)
        _linear.scales.data.copy_(scale)
        _linear.scaled_zeros.data.copy_(zero)
        return _linear

    def extra_repr(self) -> str:
        return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
            self.in_features,
            self.out_features,
            self.bias is not None,
            self.weight_bits,
            self.group_size,
        )
================================================
FILE: deepcompressor/backend/tinychat/utils.py
================================================
# -*- coding: utf-8 -*-
"""TinyChat backend utilities."""
import torch
from ..utils import ceil_divide
__all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
    """Round the number of quantization groups up to TinyChat's packing granularity.

    One INT32 word packs ``32 // weight_bits`` values, and smaller group sizes
    require the pack count to be a multiple of a size-dependent factor, so the
    effective group count is padded accordingly.

    Args:
        in_features (`int`):
            input channel size.
        group_size (`int`):
            quantization group size.
        weight_bits (`int`, *optional*, defaults to `4`):
            quantized weight bits.

    Returns:
        `int`:
            ceiling number of quantization groups.
    """
    assert in_features % group_size == 0, "input channel size should be divisible by group size."
    raw_num_groups = in_features // group_size
    assert weight_bits in (4, 2, 1), "weight bits should be 4, 2, or 1."
    pack_size = 32 // weight_bits  # one INT32 contains `pack_size` elements of weights
    num_packs = (raw_num_groups + pack_size - 1) // pack_size
    if group_size >= 128:
        num_packs_factor = 1
    elif group_size == 64:
        num_packs_factor = 2
    elif group_size == 32:
        num_packs_factor = 4
    else:
        raise NotImplementedError
    # Pad the pack count up to a multiple of the factor, then convert back to
    # a group count.
    num_packs = (num_packs + num_packs_factor - 1) // num_packs_factor * num_packs_factor
    return num_packs * pack_size
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
# [0, 1, ..., 31] -> [0, 8, 16, 24, 1, 9, 17, 25, ..., 7, 15, 23, 31]
weight = weight.view(-1, 4, 8)
weight = weight[:, 0] | (weight[:, 1] << 4) | (weight[:, 2] << 8) | (weight[:, 3] << 12)
weight = weight.view(oc // 4, 4, ic // 64, 16).permute(0, 2, 1, 3).reshape(oc // 4, ic)
return weight.to(torch.int16)
def convert_to_tinychat_w4x16y16_linear_weight(
    weight: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.

    Args:
        weight (`torch.Tensor`):
            weight tensor to be converted.
        scale (`torch.Tensor`):
            scale tensor for the weight tensor.
        zero (`torch.Tensor`):
            zero point tensor for the weight tensor.
        zero_pre_scaled (`bool`, *optional*, defaults to `False`):
            whether zero point tensor is pre-scaled.

    Returns:
        `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
            packed quantized weight tensor, scale tensor, and zero point tensor.
    """
    dtype, device = weight.dtype, weight.device
    assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
    assert scale is not None, "scale tensor is required for quantization."
    assert zero is not None, "zero point tensor is required for quantization."
    # Work in fp32 for the requantization arithmetic.
    weight = weight.to(dtype=torch.float32)
    scale = scale.to(dtype=torch.float32, device=device)
    zero = zero.to(dtype=torch.float32, device=device)
    if zero_pre_scaled:
        # Integer zero points are stored in the scaled domain: z_scaled = z * s.
        zero = zero * scale
    oc, ic = weight.shape
    if scale.numel() == 1:
        # A single global scale: treat as one group per output channel.
        scale = scale.view(1, 1).expand(oc, 1)
        ng, gs = 1, ic
    else:
        ng = scale.numel() // oc
        gs = ic // ng
    scale = scale.reshape(oc, ng).contiguous().view(oc, ng, 1)
    assert ic == gs * ng, "input channel size should be equal to group size times number of groups."
    if zero.numel() == 1:
        zero = zero.view(1, 1).expand(oc, ng)
    zero = zero.reshape(oc, ng).contiguous().view(oc, ng, 1)
    # Requantize: q = round((w + z_scaled) / s); must land in the INT4 range.
    weight = weight.view(oc, ng, -1).add_(zero).div_(scale).round_().view(oc, ic)
    assert weight.min() >= 0 and weight.max() <= 15, "quantized weight should be in [0, 15]."
    _weight = pack_w4(weight.to(torch.int32))
    _ng = ceil_num_groups(ic, gs, weight_bits=4)
    # Kernel layout: scales/zeros transposed to [num_groups, oc] and padded to
    # the packing granularity; zeros are stored NEGATED (kernel adds them).
    _scale = torch.zeros((_ng, oc), dtype=dtype, device=device)
    _zero = torch.zeros((_ng, oc), dtype=dtype, device=device)
    _scale[:ng] = scale.view(oc, ng).t().to(dtype=dtype)
    _zero[:ng] = zero.view(oc, ng).t().to(dtype=dtype).neg_()
    return _weight, _scale, _zero
================================================
FILE: deepcompressor/backend/utils.py
================================================
# -*- coding: utf-8 -*-
"""Backend utilities."""
import typing as tp
import safetensors
import torch
__all__ = ["ceil_divide", "pad", "fp_quantize", "MmaWeightPackerBase"]
def ceil_divide(x: int, divisor: int) -> int:
    """Divide ``x`` by ``divisor``, rounding up to the nearest integer.

    Args:
        x (`int`):
            dividend.
        divisor (`int`):
            divisor.

    Returns:
        `int`:
            the smallest integer ``q`` with ``q * divisor >= x`` (for positive divisors).
    """
    overshoot = x + divisor - 1
    return overshoot // divisor
def pad(
    tensor: tp.Optional[torch.Tensor],
    divisor: int | tp.Sequence[int],
    dim: int | tp.Sequence[int],
    fill_value: float | int = 0,
) -> tp.Optional[torch.Tensor]:
    """Pad `tensor` so each given dimension's size is a multiple of its divisor.

    Args:
        tensor (`torch.Tensor` or `None`):
            tensor to pad. `None` is passed through unchanged.
        divisor (`int` | `Sequence[int]`):
            divisor(s) for the target dimension size(s); values `<= 1` mean no padding.
        dim (`int` | `Sequence[int]`):
            dimension(s) to pad. When both `dim` and `divisor` are sequences,
            they are paired element-wise (and must have equal length).
        fill_value (`float` | `int`, *optional*, defaults to `0`):
            value written into the padded region.

    Returns:
        `torch.Tensor` or `None`:
            a new padded tensor, the original tensor if no padding is needed,
            or `None` if `tensor` is `None`.
    """
    # Fix: the original return annotation claimed `torch.Tensor` although
    # `None` inputs are passed through; the annotation is now Optional and the
    # None check is hoisted so it is handled uniformly.
    if tensor is None:
        return None
    if isinstance(divisor, int):
        if divisor <= 1:
            return tensor
    elif all(d <= 1 for d in divisor):
        return tensor
    shape = list(tensor.shape)
    if isinstance(dim, int):
        assert isinstance(divisor, int)
        shape[dim] = ceil_divide(shape[dim], divisor) * divisor
    else:
        # Broadcast a scalar divisor over all requested dimensions.
        if isinstance(divisor, int):
            divisor = [divisor] * len(dim)
        for d, div in zip(dim, divisor, strict=True):
            shape[d] = ceil_divide(shape[d], div) * div
    result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
    # Copy the original data into the leading corner; the rest keeps fill_value.
    result[[slice(0, extent) for extent in tensor.shape]] = tensor
    return result
def load_state_dict_in_safetensors(
    path: str, device: str | torch.device = "cpu", filter_prefix: str = ""
) -> dict[str, torch.Tensor]:
    """Read a SafeTensors checkpoint into a plain ``{name: tensor}`` dict.

    Args:
        path (`str`):
            path to the SafeTensors file.
        device (`str` | `torch.device`, optional, defaults to `"cpu"`):
            device the tensors are loaded onto.
        filter_prefix (`str`, optional, defaults to `""`):
            when non-empty, keep only keys starting with this prefix and strip
            it from the returned keys.

    Returns:
        `dict`:
            mapping from (prefix-stripped) tensor name to tensor.
    """
    tensors: dict[str, torch.Tensor] = {}
    with safetensors.safe_open(path, framework="pt", device=device) as checkpoint:
        for name in checkpoint.keys():
            if not filter_prefix or name.startswith(filter_prefix):
                tensors[name.removeprefix(filter_prefix)] = checkpoint.get_tensor(name)
    return tensors
def fp_quantize(x: torch.Tensor, codebook: torch.Tensor | None = None) -> torch.Tensor:
if codebook is None:
codebook = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
dtype=x.dtype,
device=x.device,
)
return (x.unsqueeze(-1) - codebook.unsqueeze(0)).abs().argmin(dim=-1)
class MmaWeightPackerBase:
def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
self.bits = bits
assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."
# region compute tile size
self.comp_n = comp_n if comp_n is not None else 16
"""smallest tile size in `n` dimension for MMA computation."""
self.comp_k = comp_k if comp_k is not None else 256 // self.bits
"""smallest tile size in `k` dimension for MMA computation."""
# the smallest MMA computation may contain several MMA instructions
self.insn_n = 8 # mma instruction tile size in `n` dimension
"""tile size in `n` dimension for MMA instruction."""
self.insn_k = self.comp_k
"""tile size in `k` dimension for MMA instruction."""
assert self.insn_k * self.bits in (128, 256), (
f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
)
assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
self.num_lanes = 32
"""there are 32 lanes (or threds) in a warp."""
self.num_k_lanes = 4
self.num_n_lanes = 8
assert warp_n >= self.comp_n and warp_n % self.comp_n == 0, (
f"warp_n ({warp_n}) should be divisible by comp_n({self.comp_n})."
)
self.warp_n = warp_n
# endregion
# region memory
self.reg_k = 32 // self.bits
"""number of elements in a register in `k` dimension."""
self.reg_n = 1
"""number of elements in a register in `n` dimension (always 1)."""
self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
"""number of elements in a pack in `k` dimension."""
self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
"""number of elements in a pack in `n` dimension."""
self.pack_size = self.k_pack_size * self.n_pack_size
"""number of elements in a pack accessed by a lane at a time."""
assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
self.mem_k = self.comp_k
"""the tile size in `k` dimension for one tensor memory access."""
self.mem_n = warp_n
"""the tile size in `n` dimension for one tensor memory access."""
self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
"""number of packs in `k` dimension for one tensor memory access."""
self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
"""number of packs in `n` dimension for one tensor memory access."""
# endregion
def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
    """Return the 10-dimensional logical view of an ``(n, k)`` weight for packed access.

    The first five entries tile the ``n`` dimension (memory tiles, packs,
    pack elements, lanes, register elements) and the last five tile the
    ``k`` dimension in the same order.

    Args:
        n (`int`): output channel size; must be a multiple of ``mem_n``.
        k (`int`): input channel size; must be a multiple of ``mem_k``.

    Returns:
        `tuple[int, ...]`: the ten view dimensions, ``n``-related first.
    """
    assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
    assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
    n_view = (n // self.mem_n, self.num_n_packs, self.n_pack_size, self.num_n_lanes, self.reg_n)
    k_view = (k // self.mem_k, self.num_k_packs, self.k_pack_size, self.num_k_lanes, self.reg_k)
    return n_view + k_view
================================================
FILE: deepcompressor/calib/__init__.py
================================================
# -*- coding: utf-8 -*-
================================================
FILE: deepcompressor/calib/config/__init__.py
================================================
# -*- coding: utf-8 -*-
from .lowrank import QuantLowRankCalibConfig, SkipBasedQuantLowRankCalibConfig
from .range import DynamicRangeCalibConfig, SkipBasedDynamicRangeCalibConfig
from .reorder import ChannelOrderCalibConfig, SkipBasedChannelOrderConfig
from .rotation import QuantRotationConfig
from .search import (
SearchBasedCalibConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
)
from .smooth import SkipBasedSmoothCalibConfig, SmoothCalibConfig, SmoothSpanMode, SmoothTransfomerConfig
================================================
FILE: deepcompressor/calib/config/lowrank.py
================================================
# -*- coding: utf-8 -*-
"""Quantization SVD calibration configuration."""
from dataclasses import dataclass, field
from omniconfig import configclass
from ...quantizer.config import QuantLowRankConfig
from ...utils.common import num2str
from ...utils.config import SkipBasedConfig
from .search import SearchBasedCalibConfig, SearchBasedCalibGranularity, SearchBasedCalibStrategy
__all__ = ["QuantLowRankCalibConfig", "SkipBasedQuantLowRankCalibConfig"]
@configclass
@dataclass
class QuantLowRankCalibConfig(SearchBasedCalibConfig, QuantLowRankConfig):
    """Configuration for calibrating the quantization low-rank branch.

    Args:
        rank (`int`, *optional*, defaults to `32`):
            Rank of the low-rank branch.
        exclusive (`bool`, *optional*, defaults to `False`):
            Whether each weight sharing the inputs gets its own low-rank branch.
        compensate (`bool`, *optional*, defaults to `False`):
            Whether the low-rank branch compensates the quantization error.
        degree (`int`, *optional*, default=`2`):
            Power degree used for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            Objective of the quantization calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration samples.
        sample_size (`int`, *optional*, default=`-1`):
            Size of the calibration sample set.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            Device storing the precomputed outputs of the module.
        num_iters (`int`, *optional*, default=`1`):
            Number of calibration iterations.
        early_stop (`bool`, *optional*, default=`False`):
            Whether the calibration may stop early.
    """

    granularity: SearchBasedCalibGranularity = field(init=False, default=SearchBasedCalibGranularity.Layer)
    element_batch_size: int = field(init=False, default=-1)
    element_size: int = field(init=False, default=-1)
    pre_reshape: bool = field(init=False, default=True)
    num_iters: int = 1
    early_stop: bool = False

    def __post_init__(self):
        # Any non-manual strategy is coerced to grid search.
        if self.strategy != SearchBasedCalibStrategy.Manual:
            self.strategy = SearchBasedCalibStrategy.GridSearch
        # Single-pass error compensation implies an exclusive branch per weight.
        if self.compensate and self.num_iters <= 1:
            self.exclusive = True
        super().__post_init__()

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Generate the directory names of the configuration.

        Returns:
            list[str]: The directory names.
        """
        names = super().generate_dirnames(**kwargs)
        tag = f"i{num2str(self.num_iters)}.r{num2str(self.rank)}"
        for enabled, suffix in (
            (self.exclusive, ".exclusive"),
            (self.compensate, ".compensate"),
            (self.early_stop and self.num_iters > 1, ".earlystop"),
        ):
            if enabled:
                tag += suffix
        names.append(tag)
        return [f"{prefix}.{name}" for name in names] if prefix else names
@configclass
@dataclass
class SkipBasedQuantLowRankCalibConfig(SkipBasedConfig, QuantLowRankCalibConfig):
    """Low-rank branch calibration configuration with module-skipping support.

    Combines `QuantLowRankCalibConfig` with `SkipBasedConfig` so that modules
    whose keys appear in `skips` are excluded from calibration.

    Args:
        rank (`int`, *optional*, defaults to `32`):
            Rank of the low-rank branch.
        exclusive (`bool`, *optional*, defaults to `False`):
            Whether each weight sharing the inputs gets its own low-rank branch.
        compensate (`bool`, *optional*, defaults to `False`):
            Whether the low-rank branch compensates the quantization error.
        degree (`int`, *optional*, default=`2`):
            Power degree used for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            Objective of the quantization calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration samples.
        sample_size (`int`, *optional*, default=`-1`):
            Size of the calibration sample set.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            Device storing the precomputed outputs of the module.
        num_iters (`int`, *optional*, default=`1`):
            Number of calibration iterations.
        early_stop (`bool`, *optional*, default=`False`):
            Whether the calibration may stop early.
        skips (`list[str]`, *optional*, default=`[]`):
            The keys of the modules to skip.
    """

    pass
================================================
FILE: deepcompressor/calib/config/range.py
================================================
# -*- coding: utf-8 -*-
"""Quantization dynamic range calibration configuration."""
from dataclasses import dataclass
from omniconfig import configclass
from ...utils.common import num2str
from ...utils.config import SkipBasedConfig
from .search import SearchBasedCalibConfig, SearchBasedCalibStrategy
__all__ = ["DynamicRangeCalibConfig", "SkipBasedDynamicRangeCalibConfig"]
@configclass
@dataclass
class DynamicRangeCalibConfig(SearchBasedCalibConfig):
    """Configuration for calibrating the quantization dynamic range.

    Args:
        degree (`int`, *optional*, default=`2`):
            Power degree used for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            Objective of the quantization calibration.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            Strategy of the quantization calibration.
        granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`):
            Granularity of the quantization calibration.
        element_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration elements.
        sample_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration samples.
        element_size (`int`, *optional*, default=`-1`):
            Size of the calibration element set.
        sample_size (`int`, *optional*, default=`-1`):
            Size of the calibration sample set.
        pre_reshape (`bool`, *optional*, default=`True`):
            Whether the tensor is reshaped before calibration.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            Device storing the precomputed outputs of the module.
        ratio (`float`, *optional*, default=`1.0`):
            The dynamic range ratio.
        max_shrink (`float`, *optional*, default=`0.2`):
            Maximum shrinkage ratio.
        max_expand (`float`, *optional*, default=`1.0`):
            Maximum expansion ratio.
        num_grids (`int`, *optional*, default=`80`):
            Number of grids for linear range search.
        allow_scale (`bool`, *optional*, default=`False`):
            Whether to allow range dynamic scaling.
    """

    ratio: float = 1.0
    max_shrink: float = 0.2
    max_expand: float = 1.0
    num_grids: int = 80
    allow_scale: bool = False

    def get_linear_ratios(self) -> list[float]:
        """Enumerate the dynamic range ratio candidates on a linear grid.

        Returns:
            `list[float]`:
                Shrinking ratios down to `max_shrink`, followed (when
                `max_expand > 1`) by expanding ratios up to `max_expand`.
        """
        num_grids, max_shrink, max_expand = self.num_grids, self.max_shrink, self.max_expand
        assert max_shrink < 1, "maximal shrinkage ratio must be less than 1"
        candidates: list[float] = []
        for grid in range(1, num_grids + 1):
            candidates.append(1 - grid / num_grids * (1 - max_shrink))
        if max_expand > 1:
            for grid in range(1, num_grids + 1):
                candidates.append(1 + grid / num_grids * (max_expand - 1))
        return candidates

    def get_ratios(self) -> list[list[float]]:
        """Get the dynamic range ratio candidates for the configured strategy.

        Returns:
            `list[list[float]]`:
                A single-element list in manual mode; the identity ratio plus
                the linear grid in grid-search mode.
        """
        if self.strategy == SearchBasedCalibStrategy.GridSearch:
            return [[1.0], self.get_linear_ratios()]
        if self.strategy == SearchBasedCalibStrategy.Manual:
            return [[self.ratio]]
        raise ValueError(f"Invalid strategy: {self.strategy}")

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Generate the directory names of the configuration.

        Args:
            prefix (`str`, *optional*, default=`""`):
                The prefix of the directory.

        Returns:
            `list[str]`:
                The directory names.
        """
        names = super().generate_dirnames(**kwargs)
        if self.strategy == SearchBasedCalibStrategy.Manual:
            tag = f"r.[{num2str(self.ratio)}]"
        elif self.strategy == SearchBasedCalibStrategy.GridSearch:
            tag = f"r.[{num2str(self.max_shrink)}.{num2str(self.max_expand)}].g{self.num_grids}"
        else:
            raise ValueError(f"Invalid strategy: {self.strategy}")
        if self.allow_scale:
            tag += ".scale"
        names.append(tag)
        return [f"{prefix}.{name}" for name in names] if prefix else names
@configclass
@dataclass
class SkipBasedDynamicRangeCalibConfig(SkipBasedConfig, DynamicRangeCalibConfig):
    """Dynamic range calibration configuration with module-skipping support.

    Combines `DynamicRangeCalibConfig` with `SkipBasedConfig` so that modules
    whose keys appear in `skips` are excluded from calibration.

    Args:
        degree (`int`, *optional*, default=`2`):
            Power degree used for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            Objective of the quantization calibration.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            Strategy of the quantization calibration.
        granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`):
            Granularity of the quantization calibration.
        element_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration elements.
        sample_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration samples.
        element_size (`int`, *optional*, default=`-1`):
            Size of the calibration element set.
        sample_size (`int`, *optional*, default=`-1`):
            Size of the calibration sample set.
        pre_reshape (`bool`, *optional*, default=`True`):
            Whether the tensor is reshaped before calibration.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            Device storing the precomputed outputs of the module.
        ratio (`float`, *optional*, default=`1.0`):
            The dynamic range ratio.
        max_shrink (`float`, *optional*, default=`0.2`):
            Maximum shrinkage ratio.
        max_expand (`float`, *optional*, default=`1.0`):
            Maximum expansion ratio.
        num_grids (`int`, *optional*, default=`80`):
            Number of grids for linear range search.
        allow_scale (`bool`, *optional*, default=`False`):
            Whether to allow range dynamic scaling.
        skips (`list[str]`, *optional*, default=`[]`):
            The keys of the modules to skip.
    """

    pass
================================================
FILE: deepcompressor/calib/config/reorder.py
================================================
# -*- coding: utf-8 -*-
"""Channel reorder configuration."""
import enum
from dataclasses import dataclass, field
from omniconfig import configclass
from ...utils.config import SkipBasedConfig
from .search import (
SearchBasedCalibConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
)
__all__ = ["ChannelOrderCalibConfig", "SkipBasedChannelOrderConfig"]
@configclass
@dataclass
class ChannelOrderCalibConfig(SearchBasedCalibConfig):
    """Configuration for channel order calibration in group quantization.

    Args:
        degree (`int`, *optional*, default=`2`):
            The power degree for the quantization error. Defaults to `2`.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            The strategy for quantization calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            The samples batch size for calibration.
        sample_size (`int`, *optional*, default=`-1`):
            The calibration sample size.
        allow_x_quant (`bool`, *optional*, default=`True`):
            Whether to allow input quantization during calibration.
        allow_w_quant (`bool`, *optional*, default=`True`):
            Whether to allow weight quantization during calibration.
        channel_metric (`ChannelOrderCalibConfig.ChannelMetric`, *optional*, default=`ChannelMetric.InputsAbsMax`):
            The mode for computing the channel importance.
        channel_index (`ChannelOrderCalibConfig.ChannelIndex`, *optional*, default=`ChannelIndex.Sequential`):
            The mode for ranking the channel importance.
        dynamic (`bool`, *optional*, default=`False`):
            Whether to enable dynamic channel reorder.
    """

    class ChannelMetric(enum.Enum):
        """The mode for computing the channel importance.

        Value prefixes: `x` = derived from inputs, `w` = derived from weights,
        `p` = product of input and weight statistics.
        """

        InputsAbsMax = "xMax"
        InputsAbsMean = "xAvg"
        InputsRootMeanSquare = "xRms"
        WeightsAbsMax = "wMax"
        WeightsAbsMean = "wAvg"
        WeightsRootMeanSquare = "wRms"
        AbsMaxProduct = "pMax"
        AbsMeanProduct = "pAvg"
        RootMeanSquareProduct = "pRms"

    class ChannelIndex(enum.Enum):
        """The mode for ranking the channel importance."""

        Sequential = "Seq"
        Transpose = "Trp"

    # These fields are fixed for channel-order calibration (init=False).
    objective: SearchBasedCalibObjective = field(init=False, default=SearchBasedCalibObjective.OutputsError)
    granularity: SearchBasedCalibGranularity = field(init=False, default=SearchBasedCalibGranularity.Layer)
    element_batch_size: int = field(init=False, default=-1)
    element_size: int = field(init=False, default=-1)
    pre_reshape: bool = field(init=False, default=True)
    allow_x_quant: bool = True
    allow_w_quant: bool = True
    channel_metric: ChannelMetric = ChannelMetric.InputsAbsMax
    channel_index: ChannelIndex = ChannelIndex.Sequential
    dynamic: bool = False

    def __post_init__(self) -> None:
        # Any non-manual strategy is coerced to grid search.
        if self.strategy != SearchBasedCalibStrategy.Manual:
            self.strategy = SearchBasedCalibStrategy.GridSearch
        super().__post_init__()

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Generate the directory names of the configuration.

        Args:
            prefix (`str`, *optional*, default=`""`):
                The prefix of the directory.

        Returns:
            `list[str]`:
                The directory names.
        """
        names = super().generate_dirnames(**kwargs)
        if self.strategy == SearchBasedCalibStrategy.Manual:
            # Manual mode pins a single metric/index combination.
            name = f"{self.channel_metric.name}.{self.channel_index.name}"
        else:
            name = "search"
        if self.dynamic:
            name += ".dynamic"
        names.append(name)
        # Record which tensors were excluded from quantization during calibration.
        disallows = []
        if not self.allow_x_quant:
            disallows.append("x")
        if not self.allow_w_quant:
            disallows.append("w")
        if disallows:
            names.append(f"disallow.[{'+'.join(disallows)}]")
        if prefix:
            names = [f"{prefix}.{name}" for name in names]
        return names
@configclass
@dataclass
class SkipBasedChannelOrderConfig(SkipBasedConfig, ChannelOrderCalibConfig):
    """Configuration for channel order calibration in group quantization,
    with support for skipping selected modules.

    Args:
        degree (`int`, *optional*, default=`2`):
            The power degree for the quantization error. Defaults to `2`.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            The strategy for quantization calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            The samples batch size for calibration.
        sample_size (`int`, *optional*, default=`-1`):
            The calibration sample size.
        allow_x_quant (`bool`, *optional*, default=`True`):
            Whether to allow input quantization during calibration.
        allow_w_quant (`bool`, *optional*, default=`True`):
            Whether to allow weight quantization during calibration.
        channel_metric (`ChannelOrderCalibConfig.ChannelMetric`, *optional*, default=`ChannelMetric.InputsAbsMax`):
            The mode for computing the channel importance.
        channel_index (`ChannelOrderCalibConfig.ChannelIndex`, *optional*, default=`ChannelIndex.Sequential`):
            The mode for ranking the channel importance.
        dynamic (`bool`, *optional*, default=`False`):
            Whether to enable dynamic channel reorder.
        skips (`list[str]`, *optional*, default=`[]`):
            The keys of the modules to skip.
    """

    pass
================================================
FILE: deepcompressor/calib/config/rotation.py
================================================
# -*- coding: utf-8 -*-
"""Quantization Rotation configuration."""
import os
import typing as tp
from dataclasses import dataclass, field
import omniconfig
from omniconfig import configclass
__all__ = ["QuantRotationConfig"]
@configclass
@dataclass
class QuantRotationConfig:
    """Configuration for rotation quantization.

    Args:
        name (`str`):
            The name of the rotation quantization configuration. If `path` is provided, this is required.
            Otherwise, it is set to "random" if `random` is `True`, and "hadamard" otherwise.
        path (`str`, *optional*, default=`""`):
            The path to the rotation matrix. If provided, `name` must be set.
        random (`bool`, *optional*, default=`False`):
            Whether to use a random Hadamard matrix as the rotation matrix.
        transforms (`list[str]`, *optional*, default=`[]`):
            The module keys using explicit hadamard transform.
    """

    name: str = ""
    path: str = ""
    random: bool = False
    transforms: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Deduplicate and sort the transform keys for a stable representation.
        self.transforms = sorted(set(self.transforms or []))
        if self.path and os.path.exists(self.path):
            # A rotation matrix loaded from disk must carry an explicit name;
            # the `random` flag is irrelevant in that case.
            assert self.name, "The name of the rotation quantization configuration must be provided."
            self.random = False
        else:
            # Missing or nonexistent path: fall back to a generated matrix.
            self.path = ""
            self.name = "random" if self.random else "hadamard"

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Get the directory names of the rotation quantization configuration.

        Returns:
            list[str]: The directory names of the rotation quantization configuration.
        """
        name = self.name
        if self.transforms:
            # Append the sorted transform keys, e.g. "hadamard.[k1+k2]".
            name += f".[{'+'.join(self.transforms)}]"
        return [f"{prefix}.{name}" if prefix else name]

    @classmethod
    def update_get_arguments(
        cls: type["QuantRotationConfig"],
        *,
        overwrites: dict[str, tp.Callable[[omniconfig.Arguments], None] | None] | None = None,
        defaults: dict[str, tp.Any] | None = None,
    ) -> tuple[dict[str, tp.Callable[[omniconfig.Arguments], None] | None], dict[str, tp.Any]]:
        """Get the arguments for the rotation quantization configuration."""
        overwrites = overwrites or {}
        defaults = defaults or {}
        # Presumably registers per-key boolean flags prefixed with "transform"
        # (collected back in `update_from_dict`) — behavior defined by omniconfig.
        collect_fn = omniconfig.ADD_PREFIX_BOOL_FIELDS("transform", **defaults)

        def add_transforms_argument(parser):
            # Register both the prefixed boolean flags and the plain list argument.
            collect_fn(parser)
            # NOTE(review): argparse shares `default=[]` across parses; presumably
            # each parse produces a fresh namespace — confirm if parsers are reused.
            parser.add_argument("--transforms", nargs="+", default=[], help="The keys of the modules to transform.")

        overwrites.setdefault("transforms", add_transforms_argument)
        return overwrites, defaults

    @classmethod
    def update_from_dict(
        cls: type["QuantRotationConfig"], *, parsed_args: dict[str, tp.Any], overwrites: dict[str, tp.Any]
    ) -> tuple[dict[str, tp.Any], dict[str, tp.Any]]:
        """Create a rotation quantization configuration from the parsed arguments."""
        # Fold the prefixed "transform" boolean fields back into the `transforms` list.
        parsed_args.setdefault("transforms", []).extend(omniconfig.COLLECT_PREFIX_BOOL_FIELDS(parsed_args, "transform"))
        return parsed_args, overwrites
================================================
FILE: deepcompressor/calib/config/search.py
================================================
# -*- coding: utf-8 -*-
"""Quantization calibrator configurations."""
import enum
from dataclasses import dataclass
from omniconfig import configclass
from ...utils.common import num2str
__all__ = [
"SearchBasedCalibStrategy",
"SearchBasedCalibGranularity",
"SearchBasedCalibObjective",
"SearchBasedCalibConfig",
]
class SearchBasedCalibStrategy(enum.Enum):
    """The strategy for search-based quantization calibration."""

    Manual = enum.auto()  # use the given hyperparameters as-is; no search performed
    GridSearch = enum.auto()  # search hyperparameter candidates over a grid
    # Potential future strategies, currently not implemented:
    # RandomSearch = enum.auto()
    # Bayesian = enum.auto()
    # EvolutionaryAlgorithm = enum.auto()
    # EvolutionaryStrategy = enum.auto()
class SearchBasedCalibGranularity(enum.Enum):
    """The granularity for search-based quantization calibration.

    Members are listed from finest (`Group`) to coarsest (`Layer`).
    """

    Group = enum.auto()
    ChannelGroup = enum.auto()
    Layer = enum.auto()
class SearchBasedCalibObjective(enum.Enum):
    """The objective for search-based quantization calibration."""

    TensorError = enum.auto()
    """minimize the quantization error of the tensor."""
    ProductsError = enum.auto()
    """minimize the error of the multiplication products."""
    OutputsError = enum.auto()
    """minimize the error of the outputs of the evaluation module."""
@configclass
@dataclass
class SearchBasedCalibConfig:
    """Base configuration shared by search-based quantization calibrations.

    Args:
        degree (`int`, *optional*, default=`2`):
            Power degree used for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            Objective of the quantization calibration.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            Strategy of the quantization calibration.
        granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`):
            Granularity of the quantization calibration.
        element_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration elements.
        sample_batch_size (`int`, *optional*, default=`-1`):
            Batch size of the calibration samples.
        element_size (`int`, *optional*, default=`-1`):
            Size of the calibration element set.
        sample_size (`int`, *optional*, default=`-1`):
            Size of the calibration sample set.
        pre_reshape (`bool`, *optional*, default=`True`):
            Whether the tensor is reshaped before calibration.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            Device storing the precomputed outputs of the module.
    """

    degree: int = 2
    objective: SearchBasedCalibObjective = SearchBasedCalibObjective.OutputsError
    strategy: SearchBasedCalibStrategy = SearchBasedCalibStrategy.Manual
    granularity: SearchBasedCalibGranularity = SearchBasedCalibGranularity.Layer
    element_batch_size: int = -1
    sample_batch_size: int = -1
    element_size: int = -1
    sample_size: int = -1
    pre_reshape: bool = True
    outputs_device: str = "cpu"

    def __post_init__(self) -> None:
        # Anything other than "cpu" is normalized to None, i.e. keep the
        # outputs wherever they already are instead of moving them.
        if self.outputs_device != "cpu":
            self.outputs_device = None
        if self.element_size == 0 and self.sample_size == 0:
            # Zero sizes only make sense for pure tensor-error calibration.
            assert self.objective == SearchBasedCalibObjective.TensorError
        else:
            assert self.element_batch_size != 0, "element_batch_size must not be zero"
            assert self.sample_batch_size != 0, "sample_batch_size must not be zero"
            assert self.element_size != 0, "element_size must not be zero"
            assert self.sample_size != 0, "sample_size must not be zero"
        # At layer granularity, non-tensor objectives collapse to OutputsError
        # and element-level batching becomes irrelevant.
        if self.objective != SearchBasedCalibObjective.TensorError:
            if self.granularity == SearchBasedCalibGranularity.Layer:
                self.objective = SearchBasedCalibObjective.OutputsError
                self.element_batch_size = -1
                self.element_size = -1

    @property
    def needs_search(self) -> bool:
        """Whether hyperparameter search is enabled (strategy is not manual)."""
        return self.strategy != SearchBasedCalibStrategy.Manual

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Generate the directory names of the configuration.

        Args:
            prefix (`str`, *optional*, default=`""`):
                The prefix of the directory.

        Returns:
            `list[str]`:
                The directory names.
        """
        name = ".".join(
            (
                self.objective.name,
                self.strategy.name,
                self.granularity.name,
                f"d{num2str(self.degree)}",
                f"e{num2str(self.element_size)}",
                f"s{num2str(self.sample_size)}",
            )
        )
        return [f"{prefix}.{name}"] if prefix else [name]
================================================
FILE: deepcompressor/calib/config/smooth.py
================================================
# -*- coding: utf-8 -*-
"""Smooth quantization configuration."""
import enum
from dataclasses import dataclass, field
import omniconfig
from omniconfig import configclass
from ...utils.common import num2str
from ...utils.config import SkipBasedConfig
from .search import (
SearchBasedCalibConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
)
__all__ = [
"SmoothSpanMode",
"SmoothCalibConfig",
"SmoothAttentionCalibConfig",
"SkipBasedSmoothCalibConfig",
"SmoothTransfomerConfig",
]
class SmoothSpanMode(enum.Enum):
    """The mode for computing the span used in smoothing scale calculation."""

    AbsMax = enum.auto()  # absolute-maximum-based span
    RootMeanSquare = enum.auto()  # root-mean-square-based span
@configclass
@dataclass
class SmoothCalibConfig(SearchBasedCalibConfig):
    """Configuration for smooth quantization.

    Args:
        degree (`int`, *optional*, default=`2`):
            The power degree for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            The objective for quantization calibration.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            The strategy for quantization calibration.
        granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`):
            The granularity for quantization calibration.
        element_batch_size (`int`, *optional*, default=`-1`):
            The element batch size for calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            The samples batch size for calibration.
        element_size (`int`, *optional*, default=`-1`):
            The calibration element size.
        sample_size (`int`, *optional*, default=`-1`):
            The calibration sample size.
        pre_reshape (`bool`, *optional*, default=`True`):
            Whether to enable reshaping the tensor before calibration.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            The device to store the precomputed outputs of the module.
        fuse_when_possible (`bool`, *optional*, default=`True`):
            Whether to fuse smooth scale whenever possible.
        allow_a_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for alpha tensor.
        allow_b_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for beta tensor.
        spans (`list[tuple[SmoothSpanMode, SmoothSpanMode]]`, *optional*, default=`[]`):
            The span combinations. The first element is for the alpha and the second element is for the beta.
        alpha (`float`, *optional*, default=`0.5`):
            The smoothing alpha. Negative values select a search space
            (see `get_alpha_beta_pairs`).
        beta (`float`, *optional*, default=`-1`):
            The smoothing beta. Negative values select a search space
            (see `get_alpha_beta_pairs`).
        num_grids (`int`, *optional*, default=`20`):
            The number of grids for grid search.
        allow_low_rank (`bool`, *optional*, default=`False`):
            Whether to allow quantization low-rank branch during calibration.
    """

    fuse_when_possible: bool = True
    allow_a_quant: bool = True
    allow_b_quant: bool = True
    spans: list[tuple[SmoothSpanMode, SmoothSpanMode]] = field(
        default_factory=list,
        metadata={
            omniconfig.ARGPARSE_KWARGS: {
                "nargs": "+",
                # Parse e.g. "AbsMax,RootMeanSquare" (dotted enum paths allowed)
                # into a (SmoothSpanMode, SmoothSpanMode) tuple.
                "type": lambda s: tuple(SmoothSpanMode[x.split(".")[-1]] for x in s.split(",")),
            }
        },
    )
    # Derived from `spans` in __post_init__ (init=False): distinct alpha-side
    # and beta-side span modes, in first-appearance order.
    a_spans: list[SmoothSpanMode] = field(default_factory=list, init=False)
    b_spans: list[SmoothSpanMode] = field(default_factory=list, init=False)
    alpha: float = 0.5
    beta: float = -1
    num_grids: int = 20
    allow_low_rank: bool = False

    def __post_init__(self) -> None:  # noqa: C901
        # region remove duplicates of ranges
        # Normalize string entries to SmoothSpanMode, drop duplicate pairs
        # while preserving order, and collect the distinct per-side modes.
        _spans, _spanset, _a_spanset, _b_spanset = [], set(), set(), set()
        self.a_spans, self.b_spans = [], []
        for a_span, b_span in self.spans:
            if isinstance(a_span, str):
                a_span = SmoothSpanMode[a_span]
            if isinstance(b_span, str):
                b_span = SmoothSpanMode[b_span]
            assert isinstance(a_span, SmoothSpanMode), f"Invalid span mode used for alpha: {a_span}"
            assert isinstance(b_span, SmoothSpanMode), f"Invalid span mode used for beta: {b_span}"
            _span = (a_span, b_span)
            if _span in _spanset:
                continue
            _spans.append(_span)
            _spanset.add(_span)
            if a_span not in _a_spanset:
                _a_spanset.add(a_span)
                self.a_spans.append(a_span)
            if b_span not in _b_spanset:
                _b_spanset.add(b_span)
                self.b_spans.append(b_span)
        self.spans = _spans
        # endregion
        if self.strategy == SearchBasedCalibStrategy.Manual:
            assert len(self.spans) == 1, "Only one span combination is allowed in manual mode"
            assert self.alpha != 0 or self.beta != 0, "alpha and beta cannot be both zero"
            # Resolve negative sentinels to the single concrete (alpha, beta) pair.
            self.alpha, self.beta = self.get_alpha_beta_pairs()[0]
        # Smoothing does not operate at plain Group granularity; promote it.
        if self.granularity == SearchBasedCalibGranularity.Group:
            self.granularity = SearchBasedCalibGranularity.ChannelGroup
        # The low-rank branch requires layer-level calibration.
        if self.allow_low_rank:
            self.granularity = SearchBasedCalibGranularity.Layer
        # Values below -3 are invalid even as search-space selectors.
        assert -3 <= self.alpha <= 1, "alpha must be less than or equal to 1"
        assert -3 <= self.beta <= 1, "beta must be less than or equal to 1"
        super().__post_init__()

    def get_alpha_beta_pairs(self) -> list[tuple[float, float]]:  # noqa: C901
        """Get the alpha and beta pairs for smooth quantization.

        Negative `alpha`/`beta` act as search-space selectors rather than
        literal values; see the branch comments below for each sentinel.

        Returns:
            `list[tuple[float, float]]`:
                The alpha and beta pair candidates.
        """
        if self.strategy == SearchBasedCalibStrategy.Manual:
            # Manual mode yields exactly one pair. A negative value means
            # "derive from the other side": beta < 0 -> beta = 1 - alpha,
            # alpha < 0 -> alpha = 1 - beta.
            if self.beta < 0:
                assert 0 <= self.alpha <= 1, "alpha must be in [0, 1]"
                return [(self.alpha, 1 - self.alpha)]
            elif self.alpha < 0:
                assert 0 <= self.beta <= 1, "beta must be in [0, 1]"
                return [(1 - self.beta, self.beta)]
            else:
                assert 0 <= self.alpha <= 1, "alpha must be in [0, 1]"
                assert 0 <= self.beta <= 1, "beta must be in [0, 1]"
                return [(self.alpha, self.beta)]
        # Grid candidates strictly inside (0, 1); (0, 0) is always included
        # as the "no smoothing" baseline.
        choices = [i / self.num_grids for i in range(1, self.num_grids)]
        # alpha > 0: sweep alpha; beta selects the pairing rule:
        #   beta > 0  -> beta tied to alpha; beta == 0 -> beta fixed at 0;
        #   beta == -1 -> beta = 1 - alpha; beta == -2 -> both of the above;
        #   otherwise (-3) -> full (alpha, beta) product grid plus beta = 0.
        if self.alpha > 0:
            if self.beta > 0:
                return [(0, 0)] + [(alpha, alpha) for alpha in choices]
            if self.beta == 0:
                return [(0, 0)] + [(alpha, 0) for alpha in choices]
            if self.beta == -1:
                return [(0, 0)] + [(alpha, 1 - alpha) for alpha in choices]
            if self.beta == -2:
                return [(0, 0)] + [(alpha, 0) for alpha in choices] + [(alpha, 1 - alpha) for alpha in choices]
            return (
                [(0, 0)] + [(alpha, 0) for alpha in choices] + [(alpha, beta) for alpha in choices for beta in choices]
            )
        # alpha == 0: alpha fixed at 0; sweep beta, extending the candidate
        # set according to the beta sentinel as above.
        if self.alpha == 0:
            if self.beta > 0:
                return [(0, 0)] + [(0, beta) for beta in choices]
            if self.beta == 0:
                return [(0, 0)] + [(alpha, 0) for alpha in choices] + [(0, beta) for beta in choices]
            if self.beta == -1:
                return [(0, 0)] + [(0, beta) for beta in choices] + [(alpha, 1 - alpha) for alpha in choices]
            if self.beta == -2:
                return (
                    [(0, 0)]
                    + [(alpha, 0) for alpha in choices]
                    + [(0, beta) for beta in choices]
                    + [(alpha, 1 - alpha) for alpha in choices]
                )
            return (
                [(0, 0)]
                + [(alpha, 0) for alpha in choices]
                + [(0, beta) for beta in choices]
                + [(alpha, beta) for alpha in choices for beta in choices]
            )
        # alpha == -1: search over complementary pairs beta = 1 - alpha.
        if self.alpha == -1:
            if self.beta > 0 or self.beta == -1:
                return [(0, 0)] + [(alpha, 1 - alpha) for alpha in choices]
            if self.beta == 0 or self.beta == -2:
                return [(0, 0)] + [(alpha, 0) for alpha in choices] + [(alpha, 1 - alpha) for alpha in choices]
            return (
                [(0, 0)] + [(alpha, 0) for alpha in choices] + [(alpha, beta) for alpha in choices for beta in choices]
            )
        # alpha == -2: like -1 but additionally sweeping beta alone.
        if self.alpha == -2:
            if self.beta > 0 or self.beta == -1:
                return [(0, 0)] + [(0, beta) for beta in choices] + [(alpha, 1 - alpha) for alpha in choices]
            if self.beta == 0 or self.beta == -2:
                return (
                    [(0, 0)]
                    + [(alpha, 0) for alpha in choices]
                    + [(0, beta) for beta in choices]
                    + [(alpha, 1 - alpha) for alpha in choices]
                )
            return (
                [(0, 0)]
                + [(alpha, 0) for alpha in choices]
                + [(0, beta) for beta in choices]
                + [(alpha, beta) for alpha in choices for beta in choices]
            )
        # alpha == -3: full product grid, with beta-only (and, unless beta > 0,
        # alpha-only) candidates included.
        if self.alpha == -3:
            if self.beta > 0:
                return (
                    [(0, 0)]
                    + [(0, beta) for beta in choices]
                    + [(alpha, beta) for alpha in choices for beta in choices]
                )
            return (
                [(0, 0)]
                + [(0, beta) for beta in choices]
                + [(alpha, 0) for alpha in choices]
                + [(alpha, beta) for alpha in choices for beta in choices]
            )
        raise ValueError("Invalid alpha and beta values")

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Get the directory names of the smooth quantization configuration.

        Args:
            prefix (`str`, *optional*, default=`""`):
                The prefix of the directory.

        Returns:
            `list[str]`:
                The directory names of the configuration.
        """
        names = super().generate_dirnames(**kwargs)
        # Encode the span combinations, e.g. "[a.AbsMax.b.RootMeanSquare]".
        names.append("[{}]".format("+".join(f"a.{a_span.name}.b.{b_span.name}" for a_span, b_span in self.spans)))
        alpha, beta = num2str(self.alpha), num2str(self.beta)
        if self.strategy == SearchBasedCalibStrategy.Manual:
            names.append(f"a{alpha}.b{beta}")
        elif self.alpha > 0:
            names.append(f"g{self.num_grids}.b{beta}")
        elif self.beta > 0:
            names.append(f"g{self.num_grids}.a{alpha}")
        else:
            names.append(f"g{self.num_grids}.a{alpha}.b{beta}")
        if self.allow_low_rank:
            names[-1] += ".lr"
        if not self.fuse_when_possible:
            names[-1] += ".nf"
        # Record which tensors were excluded from quantization during calibration.
        disallows = []
        if not self.allow_a_quant:
            disallows.append("a")
        if not self.allow_b_quant:
            disallows.append("b")
        if disallows:
            names.append(f"disallow.[{'+'.join(disallows)}]")
        if prefix:
            names = [f"{prefix}.{name}" for name in names]
        return names
@configclass
@dataclass
class SkipBasedSmoothCalibConfig(SkipBasedConfig, SmoothCalibConfig):
    """Smooth quantization calibration configuration with module skipping.

    Combines `SmoothCalibConfig` with `SkipBasedConfig` so that smooth
    calibration can be disabled for selected modules via ``skips``.
    All behavior is inherited from the two base classes.

    Args:
        degree (`int`, *optional*, default=`2`):
            The power degree for the quantization error. Defaults to `2`.
        objective (`SearchBasedCalibObjective`, *optional*, default=`SearchBasedCalibObjective.OutputsError`):
            The objective for quantization calibration.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            The strategy for quantization calibration.
        granularity (`SearchBasedCalibGranularity`, *optional*, default=`SearchBasedCalibGranularity.Layer`):
            The granularity for quantization calibration.
        element_batch_size (`int`, *optional*, default=`-1`):
            The element batch size for calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            The samples batch size for calibration.
        element_size (`int`, *optional*, default=`-1`):
            The calibration element size.
        sample_size (`int`, *optional*, default=`-1`):
            The calibration sample size.
        pre_reshape (`bool`, *optional*, default=`True`):
            Whether to enable reshaping the tensor before calibration.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            The device to store the precomputed outputs of the module.
        allow_a_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for alpha tensor.
        allow_b_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for beta tensor.
        spans (`list[tuple[SmoothSpanMode, SmoothSpanMode]]`, *optional*, default=`[]`):
            The span combinations. The first element is for the alpha and the second element is for the beta.
        alpha (`float`, *optional*, default=`0.5`):
            The smoothing alpha.
        beta (`float`, *optional*, default=`-1`):
            The smoothing beta.
        num_grids (`int`, *optional*, default=`20`):
            The number of grids for grid search.
        allow_low_rank (`bool`, *optional*, default=`False`):
            Whether to allow quantization SVD during calibration.
        skips (`list[str]`, *optional*, default=`[]`):
            The keys of the modules to skip.
    """

    pass
@configclass
@dataclass
class SmoothAttentionCalibConfig(SmoothCalibConfig):
    """Configuration for attention smooth quantization calibration.

    Args:
        degree (`int`, *optional*, default=`2`):
            The power degree for the quantization error. Defaults to `2`.
        strategy (`SearchBasedCalibStrategy`, *optional*, default=`SearchBasedCalibStrategy.Manual`):
            The strategy for quantization calibration.
        sample_batch_size (`int`, *optional*, default=`-1`):
            The samples batch size for calibration.
        sample_size (`int`, *optional*, default=`-1`):
            The calibration sample size.
        outputs_device (`str`, *optional*, default=`"cpu"`):
            The device to store the precomputed outputs of the module.
        allow_a_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for alpha tensor.
        allow_b_quant (`bool`, *optional*, default=`True`):
            Whether to allow the quantization for beta tensor.
        spans (`list[tuple[SmoothSpanMode, SmoothSpanMode]]`, *optional*, default=`[]`):
            The span combinations. The first element is for the alpha and the second element is for the beta.
        alpha (`float`, *optional*, default=`0.5`):
            The smoothing alpha.
        beta (`float`, *optional*, default=`-1`):
            The smoothing beta.
        num_grids (`int`, *optional*, default=`20`):
            The number of grids for grid search.
    """

    # The fields below are excluded from ``__init__`` (``init=False``) and pinned to
    # fixed defaults, specializing `SmoothCalibConfig` for attention smoothing:
    # the objective is always the outputs error, evaluated at layer granularity.
    objective: SearchBasedCalibObjective = field(init=False, default=SearchBasedCalibObjective.OutputsError)
    granularity: SearchBasedCalibGranularity = field(init=False, default=SearchBasedCalibGranularity.Layer)
    # Element-wise batching/sizing is pinned to -1 (presumably "unlimited/unused"
    # as in the sibling configs — TODO confirm against SearchBasedCalibConfig).
    element_batch_size: int = field(init=False, default=-1)
    element_size: int = field(init=False, default=-1)
    # Tensors are always reshaped before calibration for this config.
    pre_reshape: bool = field(init=False, default=True)
    # Low-rank (SVD) compensation is never used for attention smoothing.
    allow_low_rank: bool = field(init=False, default=False)
@configclass
@dataclass
class SmoothTransfomerConfig:
    """Configuration for smooth quantization of transformer-based models.

    Args:
        proj (`SkipBasedSmoothCalibConfig` or `None`, *optional*, default=`None`):
            The smooth configuration for projections.
        attn (`SmoothAttentionCalibConfig` or `None`, *optional*, default=`None`):
            The smooth configuration for attentions.
    """

    proj: SkipBasedSmoothCalibConfig | None = None
    attn: SmoothAttentionCalibConfig | None = None

    @property
    def enabled_proj(self) -> bool:
        """Whether the smooth quantization is enabled for projections."""
        return self.proj is not None

    @property
    def enabled_attn(self) -> bool:
        """Whether the smooth quantization is enabled for attentions."""
        return self.attn is not None

    def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
        """Get the names of the smooth quantization configuration.

        Args:
            prefix (`str`, *optional*, default=`""`):
                The prefix of the directory.

        Returns:
            `list[str]`:
                The names of the smooth quantization configuration
        """
        proj_names = [] if self.proj is None else self.proj.generate_dirnames(prefix="proj")
        attn_names = [] if self.attn is None else self.attn.generate_dirnames(prefix="attn")
        # Zip the two name lists position by position, joining available parts
        # with "-"; the shorter list simply contributes nothing past its end.
        names = []
        for index in range(max(len(proj_names), len(attn_names))):
            parts = [source[index] for source in (proj_names, attn_names) if index < len(source)]
            names.append("-".join(parts))
        if prefix:
            return [f"{prefix}.{name}" for name in names]
        return names
================================================
FILE: deepcompressor/calib/lowrank.py
================================================
# -*- coding: utf-8 -*-
"""Quantization SVD calibration module."""
from dataclasses import _MISSING_TYPE, MISSING
import torch
import torch.nn as nn
from ..data.common import TensorType
from ..nn.patch.lowrank import LowRankBranch
from ..quantizer.processor import Quantizer
from ..utils import math, tools
from ..utils.config import KeyEnableConfig
from .config import QuantLowRankCalibConfig, SearchBasedCalibObjective
from .search import SearchBasedCalibrator
__all__ = ["QuantLowRankCalibrator"]
class QuantLowRankCalibrator(SearchBasedCalibrator[QuantLowRankCalibConfig, LowRankBranch]):
    """The quantization low-rank branch calibrator.

    Iteratively fits a `LowRankBranch` to the residual between the original
    weight and its quantized counterpart, so that the branch compensates the
    weight-quantization error.
    """

    def __init__(
        self,
        config: QuantLowRankCalibConfig,
        w_quantizer: Quantizer,
        x_quantizer: Quantizer | None,
        develop_dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the calibrator.

        Args:
            config (`QuantLowRankCalibConfig`):
                The configuration of the quantization low-rank branch calibrator.
            w_quantizer (`Quantizer`):
                The quantizer for weights.
            x_quantizer (`Quantizer` or `None`):
                The quantizer for inputs.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The development data type.

        Raises:
            AssertionError: If the calibration is not enabled for this quantizer,
                or if the weight does not need quantization.
        """
        # Key-based configs are enabled per quantizer key; plain configs globally.
        if isinstance(config, KeyEnableConfig):
            assert config.is_enabled_for(w_quantizer.key), "The calibrator should be enabled for the quantizer."
        else:
            assert config.is_enabled(), "The calibrator should be enabled."
        super().__init__(
            tensor_type=TensorType.Weights,
            config=config,
            w_quantizer=w_quantizer,
            x_quantizer=x_quantizer,
            y_quantizer=None,
            develop_dtype=develop_dtype,
        )
        assert self.needs_quant, "The tensor should be quantized."
        self.num_iters = config.num_iters

    @property
    def population_size(self) -> int:
        """Return the population size of the current iteration."""
        # Exactly one candidate branch is evaluated per iteration.
        return 1

    @property
    def allows_x_quant_for_wgts(self) -> bool:
        """Whether the calibrator allows input quantization when tensor_type is Weights."""
        return True

    @property
    def allows_w_quant_for_wgts(self) -> bool:
        """Whether the calibrator needs weight quantization when tensor_type is Weights."""
        return True

    def is_done(self) -> bool:
        """Check if the calibration is done."""
        return self.iter >= self.num_iters or self.early_stopped

    def is_last_iter(self) -> bool:
        """Check if the current iteration is the last one."""
        return self.iter == self.num_iters - 1

    def _reset(self, x_wgts: list[torch.Tensor | nn.Parameter], **kwargs) -> None:  # noqa: C901
        """Reset the calibrator.

        Args:
            x_wgts (`list[torch.Tensor | nn.Parameter]`):
                The weights in x-w computation.
        """
        self.best_branch: LowRankBranch = None
        self.best_error: torch.Tensor = None
        # (current error, best error) per iteration, only collected for DEBUG logging.
        self.error_history: list[tuple[float, float]] = []
        self.early_stopped = False
        # With multiple weights and a non-exclusive branch, one shared branch is
        # calibrated over the row-wise (output-channel) concatenation of all weights.
        if len(x_wgts) > 1 and not self.config.exclusive:
            self.w = torch.cat([wgt.data for wgt in x_wgts], dim=0)
        else:
            assert len(x_wgts) == 1
            self.w = x_wgts[0].data
        if self.config.compensate:
            # Precompute the quantized weights so the branch fits the residual w - qw.
            self.qw = torch.cat(
                [
                    self.w_quantizer.quantize(wgt.data, kernel=None, develop_dtype=self.develop_dtype).data
                    for wgt in x_wgts
                ],
                dim=0,
            )
        else:
            # Without compensation the branch is fitted to the full weight (w - 0).
            self.qw = 0
        # Per-weight quantized candidates, filled in `_ask` and consumed in
        # `_process_w_in_xw` / `_process_wgts_centric_mod`.
        self.hat_ws: list[torch.Tensor] = [None] * len(x_wgts)
        # Output-channel counts per original weight, used to slice the concatenation.
        self.ocs: list[int] = [wgt.shape[0] for wgt in x_wgts]

    def get_best(self) -> LowRankBranch:
        """Get the best candidate.

        Returns:
            `LowRankBranch`:
                The best candidate.
        """
        return self.best_branch

    def _ask(self) -> LowRankBranch:
        """Ask for the next candidate.

        Returns:
            `LowRankBranch`:
                The next candidate.
        """
        # Fit a rank-`config.rank` branch to the current quantization residual (w - qw).
        branch = LowRankBranch(
            self.w.shape[1],
            self.w.shape[0],
            rank=self.config.rank,
            weight=self.w - self.qw,
        )
        # Cursor used by `_process_w_in_xw` to walk `hat_ws` weight by weight.
        self.wgt_idx = 0
        if len(self.hat_ws) > 1:
            # Shared branch over concatenated weights: re-quantize the remainder
            # (w minus the low-rank approximation) per original weight slice.
            lw = branch.get_effective_weight().view(self.w.shape)
            rw = self.w - lw
            oc_idx = 0
            for idx, oc in enumerate(self.ocs):
                self.hat_ws[idx] = self.w_quantizer.quantize(
                    rw[oc_idx : oc_idx + oc], kernel=None, develop_dtype=self.develop_dtype
                ).data
                oc_idx += oc
            self.qw = torch.cat(self.hat_ws, dim=0)
            if self.objective != SearchBasedCalibObjective.OutputsError:
                # For non-output objectives the evaluated weight itself must
                # include the low-rank contribution.
                oc_idx = 0
                for idx, oc in enumerate(self.ocs):
                    self.hat_ws[idx].add_(lw[oc_idx : oc_idx + oc])
                    oc_idx += oc
        else:
            lw = branch.get_effective_weight().view(self.w.shape)
            self.qw = self.w_quantizer.quantize(self.w - lw, kernel=None, develop_dtype=self.develop_dtype).data
            if self.objective != SearchBasedCalibObjective.OutputsError:
                self.hat_ws = [self.qw + lw]
            else:
                self.hat_ws = [self.qw]
        return branch

    def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
        """Tell the error of the last candidate and update the best candidate.

        Args:
            error (`list[torch.Tensor]`):
                The error of the last candidate.
        """
        # Multiple partial errors are reduced to a single scalar by summation.
        if len(error) > 1:
            error = [sum(error)]
        error = error[0]
        assert isinstance(error, torch.Tensor)
        assert error.numel() == 1, "The error should only have one value."
        if self.best_error is None or error <= self.best_error:
            self.best_error = error
            self.best_branch = self.candidate
        elif self.config.early_stop:
            # Stop the search as soon as the error stops improving.
            self.early_stopped = True
        if self.logger.level <= tools.logging.DEBUG:
            # Record (current, best) errors mapped back through the error degree.
            self.error_history.append(
                (
                    math.root_(error.to(torch.float64), self.config.degree).item(),
                    math.root_(self.best_error.to(torch.float64), self.config.degree).item(),
                )
            )
            # Dump the history every 10 iterations, and at the end of the search.
            if self.iter % 10 == 9 or self.is_last_iter() or self.early_stopped:
                iter_end = ((self.iter + 10) // 10) * 10
                iter_start = iter_end - 10
                iter_end = min(iter_end, self.iter + 1)
                history = self.error_history[iter_start:iter_end]
                self.logger.debug(" - iter = [%s]", ", ".join(f"{i:10d}" for i in range(iter_start, iter_end)))
                self.logger.debug(" - error = [%s]", ", ".join(f"{e[0]:10.4f}" for e in history))
                self.logger.debug(" - best error = [%s]", ", ".join(f"{e[1]:10.4f}" for e in history))

    def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        # Quantize the inputs only if input quantization is part of this calibration.
        if not self.needs_x_quant_for_wgts:
            return x
        return self.x_quantizer.quantize(x, channels_dim=channels_dim, develop_dtype=self.develop_dtype).data

    def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
        # Pop the precomputed candidate weight for the current module (clearing the
        # slot so the tensor can be freed) and advance the cursor.
        hat_w = self.hat_ws[self.wgt_idx]
        self.hat_ws[self.wgt_idx] = None
        self.wgt_idx += 1
        return hat_w if self.needs_w_quant_for_wgts else w

    def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        # This calibrator only operates on weights (x-w computation).
        raise RuntimeError("_process_y_in_yx should not be called in QuantSVDCalibrator.")

    def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        raise RuntimeError("_process_x_in_yx should not be called in QuantSVDCalibrator.")

    def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("_process_xw_in_yx should not be called in QuantSVDCalibrator.")

    def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("_process_yw_in_yx should not be called in QuantSVDCalibrator.")

    def _process_wgts_centric_mod(
        self, wgts: list[nn.Parameter], mods: list[nn.Module], update_state_dict: bool = True, **kwargs
    ) -> None:
        """Swap module weights for their quantized candidates and attach the
        current low-rank branch as forward hooks for output-error evaluation."""
        assert len(self.hat_ws) == len(wgts) == len(mods)
        shared = self.candidate
        if len(self.hat_ws) > 1:
            # One shared branch: each module gets its own slice of the output-side
            # (`b`) weight while sharing the input-side (`a`) projection.
            oc_idx = 0
            for mod, wgt, hat_w in zip(mods, wgts, self.hat_ws, strict=True):
                if update_state_dict:
                    # Remember the original data so it can be restored afterwards.
                    self._state_dict.append((wgt, wgt.data))
                wgt.data = hat_w
                branch = LowRankBranch(wgt.shape[1], wgt.shape[0], rank=self.config.rank)
                branch.a = shared.a
                branch.b.to(dtype=wgt.dtype, device=wgt.device)
                branch.b.weight.copy_(shared.b.weight[oc_idx : oc_idx + wgt.data.shape[0]])
                oc_idx += wgt.data.shape[0]
                self._hooks.append(branch.as_hook().register(mod))
        else:
            if update_state_dict:
                self._state_dict.append((wgts[0], wgts[0].data))
            wgts[0].data = self.hat_ws[0]
            self._hooks.append(shared.as_hook().register(mods))
        if self.needs_x_quant_for_wgts:
            self._hooks.append(self.x_quantizer.as_hook().register(mods))
        # Drop references so the candidate tensors can be garbage-collected.
        self.hat_ws = [None] * len(self.hat_ws)
================================================
FILE: deepcompressor/calib/metric.py
================================================
# -*- coding: utf-8 -*-
"""Channel-wise metric calculation module."""
import typing as tp
import torch
from ..data.utils.shape import infer_view_shape
__all__ = ["ChannelMetric"]
class ChannelMetric:
"""Channel-wise metric."""
@staticmethod
def _normalize(
tensor: torch.Tensor,
group_shape: tp.Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
shape, ndim = tensor.shape, tensor.ndim
view_shape = infer_view_shape(tensor.shape, group_shape)
# (d0, d1, d2, ...) -> (#g0, gs0, #g1, gs1, #g2, gs2, ...)
tensor = tensor.view(view_shape).to(dtype=dtype)
tensor_max = tensor.abs().amax(dim=list(range(1, ndim * 2, 2)), keepdim=True)
tensor_max[tensor_max == 0] = 1
tensor = tensor / tensor_max
return tensor.view(shape)
@staticmethod
def _abs_max(
tensor: torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, int]:
return (
tensor.view(tensor.shape[0], num_channels, -1)
.abs()
.amax(dim=(0, 2))
.view(-1)
.to(dtype=dtype, device=device),
1,
)
@staticmethod
def _abs_sum(
tensor: torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, int]:
tensor = tensor.view(tensor.shape[0], num_channels, -1)
cnt = tensor.shape[0] * tensor.shape[2]
return tensor.abs().to(dtype=dtype).sum(dim=(0, 2)).view(-1).to(device=device), cnt
@staticmethod
def _abs_normalize_sum(
tensor: torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, int]:
return ChannelMetric._abs_sum(
ChannelMetric._normalize(tensor, group_shape, dtype=dtype),
num_channels,
group_shape,
device=device,
dtype=dtype,
)
@staticmethod
def _square_sum(
tensor: torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device,
dtype: torch.dtype,
) -> tuple[torch.Tensor, int]:
tensor = tensor.view(tensor.shape[0], num_channels, -1)
cnt = tensor.shape[0] * tensor.shape[2]
return tensor.to(dtype=dtype).pow(2).sum(dim=(0, 2)).view(-1).to(device=device), cnt
@staticmethod
def _max_reduce(
fn: tp.Callable[
[torch.Tensor, int, tp.Sequence[int], torch.device, torch.dtype],
tuple[torch.Tensor, torch.Tensor | int | float],
],
tensors: tp.Sequence[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor | int | float]:
if isinstance(tensors, torch.Tensor):
device = device or tensors.device
return fn(tensors, num_channels, group_shape, device, dtype)
else:
rst_0, rst_1 = ChannelMetric._max_reduce(fn, tensors[0], num_channels, group_shape, device, dtype)
for tensor in tensors[1:]:
_rst_0, _rst_1 = ChannelMetric._max_reduce(fn, tensor, num_channels, group_shape, device, dtype)
rst_0 = torch.maximum(rst_0, _rst_0.to(device=rst_0.device))
if isinstance(rst_1, torch.Tensor):
rst_1 = torch.maximum(rst_1, _rst_1.to(device=rst_1.device))
else:
rst_1 = max(rst_1, _rst_1)
return rst_0, rst_1
@staticmethod
def _sum_reduce(
fn: tp.Callable[
[torch.Tensor, int, tp.Sequence[int], torch.device, torch.dtype],
tuple[torch.Tensor, torch.Tensor | int | float],
],
tensors: tp.Sequence[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor | int | float]:
if isinstance(tensors, torch.Tensor):
device = device or tensors.device
return fn(tensors.to(device), num_channels, group_shape, device, dtype)
else:
assert isinstance(tensors, (list, tuple))
rst_0, rst_1 = ChannelMetric._sum_reduce(fn, tensors[0], num_channels, group_shape, device, dtype)
for tensor in tensors[1:]:
_rst_0, _rst_1 = ChannelMetric._sum_reduce(fn, tensor, num_channels, group_shape, device, dtype)
rst_0 += _rst_0.to(device=rst_0.device)
if isinstance(rst_1, torch.Tensor):
rst_1 += _rst_1.to(device=rst_1.device)
else:
rst_1 += _rst_1
return rst_0, rst_1
@staticmethod
def abs_max(
tensors: tp.Iterable[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Get the absolute maximum of the tensors, where `R[i] = AbsMax(T[i, :])`."""
return ChannelMetric._max_reduce(
ChannelMetric._abs_max, tensors, num_channels, group_shape, device=device, dtype=dtype
)[0]
@staticmethod
def abs_mean(
tensors: tp.Iterable[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Get the absolute mean of the tensors, where `R[i] = AbsMean(T[i, :])`."""
rst, cnt = ChannelMetric._sum_reduce(
ChannelMetric._abs_sum, tensors, num_channels, group_shape, device=device, dtype=dtype
)
return rst.div_(cnt)
@staticmethod
def abs_normalize_mean(
tensors: tp.Iterable[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Get the absolute group normalized mean of the tensors, where `R[i] = Mean(U[i, :])`
and `U[i,j] = Abs(T[i, j]) / AbsMax(T[:, j]))`."""
rst, cnt = ChannelMetric._sum_reduce(
ChannelMetric._abs_normalize_sum, tensors, num_channels, group_shape, device=device, dtype=dtype
)
return rst.div_(cnt)
@staticmethod
def root_mean_square(
tensors: tp.Iterable[torch.Tensor] | torch.Tensor,
num_channels: int,
group_shape: tp.Sequence[int],
device: torch.device | str = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""Get the root mean square of the tensors, where `R[i] = Root(Mean(T[i, :]^2))`."""
rst, cnt = ChannelMetric._sum_reduce(
ChannelMetric._square_sum, tensors, num_channels, group_shape, device=device, dtype=dtype
)
return rst.div_(cnt).sqrt_()
================================================
FILE: deepcompressor/calib/range.py
================================================
# -*- coding: utf-8 -*-
"""Quantization dynamic range calibration."""
import gc
import typing as tp
from dataclasses import _MISSING_TYPE, MISSING
import torch
import torch.nn as nn
from ..data.cache import TensorsCache
from ..data.common import TensorType
from ..data.range import DynamicRange
from ..data.scale import QuantScale
from ..data.utils.shape import infer_view_shape
from ..quantizer.impl.info import QuantInfo
from ..quantizer.processor import Quantizer
from ..utils import math, tools
from .config import DynamicRangeCalibConfig, SearchBasedCalibGranularity
from .search import SearchBasedCalibrator
__all__ = ["DynamicRangeCalibrator", "calibrate_dynamic_range"]
class DynamicRangeCalibrator(SearchBasedCalibrator[DynamicRangeCalibConfig, DynamicRange]):
    """The quantization dynamic range calibrator.

    Searches over a set of range ratios and keeps, per search granularity
    (layer / channel group / quantization group), the ratio minimizing the
    quantization error.
    """

    def __init__(
        self,
        tensor_type: TensorType,
        config: DynamicRangeCalibConfig,
        static: bool,
        quantizer: Quantizer,
        pre_scale: torch.Tensor | None = None,
    ) -> None:
        """Initialize the calibrator.

        Args:
            tensor_type (`TensorType`):
                The tensor type.
            config (`DynamicRangeCalibConfig`):
                The dynamic range calibration configuration.
            static (`bool`):
                Whether the dynamic range is static, i.e., whether the quantization is static.
            quantizer (`Quantizer`):
                The quantizer.
            pre_scale (`torch.Tensor` or `None`):
                The joint scale tensor of the previous quantization steps.
        """
        # The single quantizer is routed to the slot matching the tensor type.
        super().__init__(
            tensor_type=tensor_type,
            config=config,
            w_quantizer=quantizer if tensor_type == TensorType.Weights else None,
            x_quantizer=quantizer if tensor_type == TensorType.Inputs else None,
            y_quantizer=quantizer if tensor_type == TensorType.Outputs else None,
            develop_dtype=quantizer.develop_dtype,
        )
        assert self.needs_quant, "The tensor should be quantized."
        self.static = static
        self.pre_scale = pre_scale
        # One list of candidate ratios per iteration.
        self.ratios = self.config.get_ratios()
        self.num_iters = len(self.ratios)

    @property
    def population_size(self) -> int:
        """Return the population size of the current iteration."""
        return len(self.ratios[self.iter])

    def is_clamp_based(self) -> bool:
        """Return whether the calibration is clamp-based.

        Clamp-based calibration shrinks the precomputed min/max bounds; otherwise
        the ratio is applied as a scale at quantization time.
        """
        return self.static or not self.config.allow_scale

    def _reset(  # noqa: C901
        self,
        x_wgts: list[torch.Tensor | nn.Parameter],
        x_acts: TensorsCache | None,
        y_acts: TensorsCache | None,
        **kwargs,
    ) -> None:
        """Reset the calibrator.

        Args:
            x_wgts (`list[torch.Tensor | nn.Parameter]`):
                The weights in x-w computation.
            x_acts (`TensorsCache` or `None`):
                The x activations in x-w computation.
            y_acts (`TensorsCache` or `None`):
                The y activations in y-x computation.
        """
        self.base_range: DynamicRange = DynamicRange()
        self.best_range: DynamicRange = DynamicRange()
        self.best_error: torch.Tensor = None
        # (sum error, best sum error) per candidate, only collected for DEBUG logging.
        self.error_history: list[tuple[float, float]] = []
        self.device = None
        # Select the tensors to calibrate on, depending on the tensor type.
        if self.tensor_type == TensorType.Weights:
            assert len(x_wgts) == 1, "The weight should be a single tensor."
            wgts = x_wgts[0].data
            assert isinstance(wgts, torch.Tensor), "The weight should be a tensor."
            tensors = [wgts]
            self.device = wgts.device
        elif self.tensor_type == TensorType.Inputs:
            assert x_acts is not None, "The input activations should be provided."
            assert x_acts.num_tensors == 1, f"Only one input is allowed, got {x_acts.num_tensors}"
            acts = x_acts.front()
            tensors = acts.get_standardized_data(reshape=False)
            self.device = acts.orig_device
        else:
            assert y_acts is not None, "The output activations should be provided."
            assert y_acts.num_tensors == 1, f"Only one output is allowed, got {y_acts.num_tensors}"
            acts = y_acts.front()
            tensors = acts.get_standardized_data(reshape=False)
            self.device = acts.orig_device
        shape = tensors[0].shape
        # Grouped view: (d0, d1, ...) -> (#g0, gs0, #g1, gs1, ...).
        view_shape = infer_view_shape(
            shape,
            self.quantizer.config.largest_group_shape,
            skip_first_dim=self.tensor_type != TensorType.Weights,
        )
        # region get range scale shape
        # Broadcastable shape exposing only the channel-group dimension (view_shape[2]).
        self.pos_view_shape = torch.Size([1, 1, view_shape[2], *([1] * (len(view_shape) - 3))])
        # One entry per group: group-size dims collapsed to 1, group-count dims kept.
        self.range_shape = torch.Size([gs if i % 2 == 0 else 1 for i, gs in enumerate(view_shape)])
        if self.granularity == SearchBasedCalibGranularity.Layer:
            # A single ratio/error for the whole tensor.
            self.ratio_shape = self.error_shape = torch.Size((1,))
            self.ratio_view_shape = self.ratio_shape
        elif self.granularity == SearchBasedCalibGranularity.ChannelGroup:
            # One ratio/error per channel group.
            self.ratio_shape = self.error_shape = torch.Size((view_shape[2],))
            self.ratio_view_shape = self.pos_view_shape
        elif self.granularity == SearchBasedCalibGranularity.Group:
            # One ratio/error per quantization group.
            self.ratio_shape = self.error_shape = torch.Size(view_shape[::2])
            self.ratio_view_shape = self.range_shape
        else:
            raise ValueError(f"Invalid granularity: {self.granularity}")
        assert self.ratio_shape.numel() == self.ratio_view_shape.numel()
        if self.pre_scale is not None:
            # pre_scale carries a (group count, group size) pair per original dim.
            assert len(shape) * 2 == len(self.pre_scale.shape)
            self.pre_view_shape = infer_view_shape(shape, self.pre_scale.shape[1::2])
        else:
            self.pre_view_shape = torch.Size()
        # endregion
        if self.is_clamp_based():
            # Precompute the full min/max range once; candidates will shrink it.
            if self.pre_scale is not None:
                tensors = [self._preprocess_with_pre_scale(t) for t in tensors]
            tensors = [t.view(view_shape).to(dtype=self.develop_dtype) for t in tensors]
            self.base_range = DynamicRange.construct(
                tensors,
                zero_domain=self.quantizer.config.zero_domain,
                is_float_point=self.quantizer.config.quant_dtype.is_float_point,
            )
        gc.collect()
        torch.cuda.empty_cache()

    def get_best(self) -> DynamicRange:
        """Get the best candidate.

        Returns:
            `DynamicRange`:
                The best candidate.
        """
        if self.static:
            # Static quantization keeps the chosen min/max bounds.
            return DynamicRange(min=self.best_range.min, max=self.best_range.max)
        elif self.is_clamp_based():
            # Dynamic clamp-based: keep bounds but reset the ratio to 1.
            return DynamicRange(min=self.best_range.min, max=self.best_range.max, ratio=1.0)
        else:
            # Scale-based: only the per-group ratio matters.
            return DynamicRange(ratio=self.best_range.ratio.view(self.ratio_view_shape))

    def _ask(self) -> DynamicRange:
        """Ask for the next candidate.

        Returns:
            `DynamicRange`:
                The next candidate.
        """
        ratio = self.ratios[self.iter][self.candidate_id]
        if self.is_clamp_based():
            # Shrink the precomputed base range by the candidate ratio.
            return self.base_range.scale(
                ratio=ratio,
                zero_domain=self.quantizer.config.zero_domain,
                is_float_point=self.quantizer.config.quant_dtype.is_float_point,
            )
        else:
            return DynamicRange(ratio=ratio)

    def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
        """Tell the error of the last candidate and update the best candidate.

        Args:
            error (`list[torch.Tensor]`):
                The error of the last candidate.
        """
        assert len(error) == 1, "The error should only have one value."
        error = error[0]
        assert isinstance(error, torch.Tensor)
        assert error.shape == self.error_shape, f"Error shape {error.shape} != {self.error_shape}."
        assert isinstance(self.candidate, DynamicRange)
        candidate_ratio = self.ratios[self.iter][self.candidate_id]
        if self.best_error is None:
            # First candidate: adopt unconditionally.
            self.best_error = error
            if self.is_clamp_based():
                self.best_range.min = self.candidate.min
                self.best_range.max = self.candidate.max
            self.best_range.ratio = torch.full(
                size=self.ratio_shape, fill_value=candidate_ratio, device=self.device, dtype=self.develop_dtype
            )
        elif error.numel() > 1:
            # Elementwise update: improve only the positions where the error decreased.
            pos = error < self.best_error
            self.best_error[pos] = error[pos]
            if self.is_clamp_based():
                if self.error_shape.numel() != self.range_shape.numel():
                    # Broadcast the mask from per-channel-group to per-group.
                    pos = pos.view(self.pos_view_shape).expand(*self.range_shape)
                else:
                    pos = pos.view(self.range_shape)
                self.best_range.max[pos] = self.candidate.max[pos]
                if isinstance(self.candidate.min, torch.Tensor):
                    self.best_range.min[pos] = self.candidate.min[pos]
            # NOTE(review): after the `expand` above, `pos` is non-contiguous and its
            # numel differs from `ratio_shape` — this `view` looks only valid on the
            # non-expanded paths; confirm which granularity/clamp combinations reach it.
            self.best_range.ratio[pos.view(self.ratio_shape)] = candidate_ratio
        elif error < self.best_error:
            # Scalar update: replace the best candidate wholesale.
            self.best_error = error
            if self.is_clamp_based():
                self.best_range.min = self.candidate.min
                self.best_range.max = self.candidate.max
            self.best_range.ratio.fill_(candidate_ratio)
        if self.logger.level <= tools.logging.DEBUG:
            # Record (sum error, best sum error) mapped back through the error degree.
            self.error_history.append(
                (
                    math.root_(error.to(torch.float64).sum(), self.config.degree).item(),
                    math.root_(self.best_error.to(torch.float64).sum(), self.config.degree).item(),
                )
            )
            # After the last candidate of an iteration, dump the history in rows of 5.
            if self.is_last_candidate_in_iter():
                stype_id = self.iter
                ratios, population_size = self.ratios[stype_id], self.population_size
                for i in range(0, population_size, 5):
                    self.logger.debug(
                        " - range ratio = [%s]",
                        ", ".join(f"{ratios[j]:10.4f}" for j in range(i, min(i + 5, population_size))),
                    )
                    self.logger.debug(
                        " sum error = [%s]",
                        ", ".join(f"{self.error_history[j][0]:10.4f}" for j in range(i, min(i + 5, population_size))),
                    )
                    self.logger.debug(
                        " best error = [%s]",
                        ", ".join(f"{self.error_history[j][1]:10.4f}" for j in range(i, min(i + 5, population_size))),
                    )
                self.error_history.clear()
                if self.is_last_iter():
                    self.logger.debug(
                        "+ error = [%.4f]",
                        math.root_(self.best_error.to(torch.float64).sum(), self.config.degree).item(),
                    )

    def _preprocess_with_pre_scale(self, t: torch.Tensor) -> torch.Tensor:
        """Divide the tensor by the previous steps' joint scale (on a copy) and
        clamp it to the quantizer's range bound if one is set."""
        t = t.view(self.pre_view_shape)
        # Clone when already in develop dtype so the in-place div does not mutate the input.
        t = t.to(dtype=self.develop_dtype) if t.dtype != self.develop_dtype else t.clone()
        t = t.div_(self.pre_scale)
        if self.quantizer.range_bound is not None and self.quantizer.range_bound.is_set():
            t = t.clamp_(min=self.quantizer.range_bound.min, max=self.quantizer.range_bound.max)
        return t

    def _process_wxy(self, tensor: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        """Quantize the tensor with the candidate dynamic range, undoing and
        re-applying the pre-scale around the quantization."""
        shape, dtype = tensor.shape, tensor.dtype
        if self.pre_scale is not None:
            tensor = self._preprocess_with_pre_scale(tensor).view(shape)
        tensor = self.quantizer.quantize(
            tensor,
            kernel=None,
            channels_dim=channels_dim,
            dynamic_range=self.candidate,
            default_dtype=dtype,
            develop_dtype=self.develop_dtype,
        ).data
        if self.pre_scale is not None:
            tensor = tensor.view(self.pre_view_shape).mul_(self.pre_scale).to(dtype)
            tensor = tensor.view(shape)
        return tensor

    def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        # Only quantize the tensor kind this calibrator was built for.
        if self.tensor_type != TensorType.Inputs:
            return x
        return self._process_wxy(x, channels_dim)

    def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
        if self.tensor_type != TensorType.Weights:
            return w
        return self._process_wxy(w, channels_dim=None)

    def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        if self.tensor_type != TensorType.Outputs:
            return y
        return self._process_wxy(y, channels_dim)

    def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        raise RuntimeError("_process_x_in_yx should not be called in DynamicRangeCalibrator.")

    def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("_process_xw_in_yx should not be called in DynamicRangeCalibrator.")

    def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        raise RuntimeError("_process_yw_in_yx should not be called in DynamicRangeCalibrator.")
def calibrate_dynamic_range(
    tensor_type: TensorType,
    config: DynamicRangeCalibConfig | None,
    static: bool,
    quantizer: Quantizer,
    modules: tp.Sequence[nn.Module],
    activations: TensorsCache,
    weights: tp.Sequence[nn.Parameter] | None = None,
    eval_inputs: TensorsCache | None = None,
    eval_module: nn.Module | None = None,
    eval_kwargs: dict[str, tp.Any] | None = None,
    orig_weights: tp.Sequence[tuple[nn.Parameter, torch.Tensor]] | None = None,
    orig_activations: TensorsCache | None = None,
    orig_eval_inputs: TensorsCache | None = None,
) -> tp.Sequence[DynamicRange] | None:
    """Calibrate the dynamic range.

    Args:
        tensor_type (`TensorType`):
            The tensor type.
        config (`DynamicRangeCalibConfig`):
            The quantization dynamic range calibration configuration.
        static (`bool`):
            Whether the dynamic range is static.
        quantizer (`Quantizer`):
            The quantizer.
        modules (`Sequence[nn.Module]`):
            The modules to calibrate.
        activations (`TensorsCache`):
            The inputs cache if the tensor type is not outputs, or the outputs cache if the tensor type is outputs.
        weights (`Sequence[nn.Parameter]` or `None`, *optional*, defaults to `None`):
            The weights to calibrate.
            If not provided, the weights of the modules will be used.
        eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the inputs for evaluation.
            If not provided, the `activations` cache will be used.
        eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
            The module to evaluate the quantization error.
            If not provided, the module to calibrate will be used.
        eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
            The keyword arguments for evaluation.
        orig_weights (`Sequence[tuple[nn.Parameter, torch.Tensor]]` or `None`, *optional*, defaults to `None`):
            The original weights.
        orig_activations (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The original activations.
        orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The original evaluation inputs.

    Returns:
        `Sequence[DynamicRange]` or `None`:
            The dynamic ranges of each quantization step.
    """
    if config is None or not quantizer.is_enabled():
        return None
    decomposed_config = quantizer.config.decompose()
    num_steps = decomposed_config.num_steps
    # region dynamic range without search
    if not config.needs_search and (not static or tensor_type == TensorType.Weights):
        # Without search, only a manual shrink ratio (if any) applies, and only to the first step.
        if config.ratio != 1.0:
            dynamic_range = DynamicRange(ratio=config.ratio)
            return tuple([dynamic_range] + [None] * (num_steps - 1))
        else:
            return None
    # endregion
    # region prepare for search
    if weights is None:
        weights = [module.weight for module in modules if hasattr(module, "weight")]
    if tensor_type == TensorType.Weights:
        assert len(modules) == 1, "only one module is supported for weight quantization calibration"
        assert len(weights) == 1, "only one weight is supported for weight quantization calibration"
        if eval_module is None:
            # Default: evaluate on the module being calibrated with its own cached inputs.
            eval_module = modules[0]
            if eval_inputs is None:
                eval_inputs = activations
        else:
            assert eval_inputs is not None, "eval_inputs is required when eval_module is provided"
    else:
        assert activations is not None, "activations is required for activation quantization calibration"
        assert activations.num_tensors == 1, "only one tensor is supported for activation quantization calibration"
    # Route weights/activations/modules to the x-side (inputs/weights) or y-side (outputs).
    if tensor_type != TensorType.Outputs:
        x_wgts, x_acts, x_mods, orig_x_wgts, orig_x_acts = weights, activations, modules, orig_weights, orig_activations
        y_wgts, y_acts, y_mods, orig_y_wgts, orig_y_acts = [], None, None, None, None
    else:
        x_wgts, x_acts, x_mods, orig_x_wgts, orig_x_acts = [], None, None, None, None
        y_wgts, y_acts, y_mods, orig_y_wgts, orig_y_acts = weights, activations, modules, orig_weights, orig_activations
    # endregion
    if num_steps == 1:
        # Single-step quantization: one calibration pass is sufficient.
        dynamic_range = DynamicRangeCalibrator(
            tensor_type=tensor_type,
            config=config,
            static=static,
            quantizer=quantizer,
        ).calibrate(
            x_wgts=x_wgts,
            y_wgts=y_wgts,
            x_acts=x_acts,
            y_acts=y_acts,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            x_mods=x_mods,
            y_mods=y_mods,
            orig_x_wgts=orig_x_wgts,
            orig_y_wgts=orig_y_wgts,
            orig_x_acts=orig_x_acts,
            orig_y_acts=orig_y_acts,
            orig_eval_inputs=orig_eval_inputs,
        )
        return (dynamic_range,)
    # region prepare for search with progressive quantization
    if tensor_type == TensorType.Weights:
        tensor = weights[0].detach().data
    else:
        assert activations.num_tensors == 1, "Only one tensor is supported for activation quantization"
        acts = activations.front()
        # Fixed: the original asserted `len(acts.data) == 0`, which contradicts the
        # `acts.data[0]` access below; exactly one cached tensor must be present.
        assert len(acts.data) == 1, "Only one tensor is supported for activation quantization"
        tensor = acts.data[0].detach().data
        if acts.channels_dim is not None:
            # Flatten all leading dims so channels start at dim 1.
            tensor = tensor.reshape(-1, *tensor.shape[acts.channels_dim :])
    develop_dtype = quantizer.develop_dtype
    default_scale_dtype = quantizer.default_dtype or tensor.dtype
    develop_tensor = tensor.to(dtype=develop_dtype) if tensor.dtype != develop_dtype else tensor.clone()
    del tensor
    # endregion
    info = QuantInfo.construct(
        decomposed_config,
        tensor_shape=develop_tensor.shape,
        default_dtype=default_scale_dtype,
        quant_range=quantizer.quant_range,
        range_bound=quantizer.range_bound,
    )
    dynamic_ranges = []
    quant_scale = QuantScale()
    for step, step_info in enumerate(info.steps):
        # Only the final step carries the quantization kernel.
        step_quantizer = Quantizer(
            config=step_info.to_config(),
            kernel=quantizer.kernel if step == num_steps - 1 else None,
            quant_range=step_info.quant_range,
            range_bound=step_info.range_bound,
            default_dtype=quantizer.default_dtype,
            develop_dtype=quantizer.develop_dtype,
        )
        # Calibrate this step's dynamic range on top of the scales accumulated so far.
        step_dynamic_range = DynamicRangeCalibrator(
            tensor_type=tensor_type,
            config=config,
            static=static,
            quantizer=step_quantizer,
            pre_scale=quant_scale.data,
        ).calibrate(
            x_wgts=x_wgts,
            y_wgts=y_wgts,
            x_acts=x_acts,
            y_acts=y_acts,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            x_mods=x_mods,
            y_mods=y_mods,
            orig_x_wgts=orig_x_wgts,
            orig_y_wgts=orig_y_wgts,
        )
        dynamic_ranges.append(step_dynamic_range)
        step_scale, _ = step_info.scale.quantize(
            tensor=develop_tensor.view(step_info.tensor_shape),
            dynamic_range=step_dynamic_range,
        )
        quant_scale.append(step_scale)
        if num_steps > 2 and step < num_steps - 1:
            # Rescale and clamp so the next step calibrates on the residual range.
            # (Skipped for num_steps == 2: the last step's scale computed above is
            # presumably never consumed after the loop — verify before relying on it.)
            step_quant_range = step_info.tensor_quant_range
            develop_tensor = develop_tensor.view(step_info.tensor_view_shape).div_(step_scale.data)
            develop_tensor = develop_tensor.clamp_(min=step_quant_range.min, max=step_quant_range.max)
    return tuple(dynamic_ranges)
================================================
FILE: deepcompressor/calib/reorder.py
================================================
# -*- coding: utf-8 -*-
"""Channel reordering module."""
import gc
import typing as tp
from dataclasses import _MISSING_TYPE, MISSING, dataclass
import torch
import torch.nn as nn
from ..data.cache import TensorsCache
from ..data.common import TensorType
from ..quantizer.processor import Quantizer
from ..utils import math, tools
from ..utils.hooks import BaseInputPackager, BaseOutputPackager, BaseTensorProcessor
from .config import (
ChannelOrderCalibConfig,
SearchBasedCalibGranularity,
SearchBasedCalibObjective,
SearchBasedCalibStrategy,
)
from .metric import ChannelMetric
from .search import SearchBasedCalibrator
__all__ = ["ChannelOrderCalibrator", "ChannelReorderer"]
@dataclass
class ChannelReorderer(BaseTensorProcessor):
    """Activation channel reordering processor."""

    index: torch.Tensor
    channels_dim: int
    # region hook-related attributes
    input_packager: BaseInputPackager | None = None
    output_packager: BaseOutputPackager | None = None
    # endregion

    def is_enabled(self) -> bool:
        """Reordering is active only when a channel index has been set."""
        return self.index is not None

    def get_input_packager(self) -> BaseInputPackager | None:
        """Return the packager used to unpack/repack module inputs."""
        return self.input_packager

    def get_output_packager(self) -> BaseOutputPackager | None:
        """Return the packager used to unpack/repack module outputs."""
        return self.output_packager

    def process(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reorder the channels of ``tensor`` along ``channels_dim``.

        Args:
            tensor (torch.Tensor): The tensor to process.

        Returns:
            torch.Tensor: The processed tensor.
        """
        # Keep the index on the tensor's device to avoid repeated transfers.
        idx = self.index.to(device=tensor.device)
        self.index = idx
        return tensor.index_select(dim=self.channels_dim, index=idx)
def get_channel_index_from_rank(
    rank: torch.Tensor,
    num_channels: int,
    num_groups: int,
    index_mode: ChannelOrderCalibConfig.ChannelIndex,
) -> torch.Tensor:
    """Get the index from the rank.

    Args:
        rank (`torch.Tensor`):
            The rank of the channels.
        num_channels (`int`):
            The number of channels.
        num_groups (`int`):
            The number of groups.
        index_mode (`ChannelOrderCalibConfig.ChannelIndex`):
            The index mode.

    Returns:
        `torch.Tensor`:
            The index of the channels, i.e., the order of the channels.
    """
    if index_mode == ChannelOrderCalibConfig.ChannelIndex.Sequential:
        return rank
    if index_mode == ChannelOrderCalibConfig.ChannelIndex.Transpose:
        # Fill a (group_size, num_groups) grid row-wise with the ranked channels,
        # then read it column-wise so adjacent ranks land in different groups.
        rows = num_channels // num_groups
        return rank.view(rows, num_groups).t().reshape(-1)
    raise ValueError(f"Unsupported index mode: {index_mode}")
def get_channel_metric(
    inputs: TensorsCache,
    weights: tp.Sequence[torch.Tensor],
    metric_mode: ChannelOrderCalibConfig.ChannelMetric,
    num_channels: int,
    num_heads: int = 1,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Get the metric value of the channels.

    Args:
        inputs (`TensorsCache`):
            The input activations.
        weights (`Sequence[torch.Tensor]`):
            The weight tensors.
        metric_mode (`ChannelOrderCalibConfig.ChannelMetric`):
            The channel metric mode.
        num_channels (`int`):
            The number of channels.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        device (`torch.device` or `str` or `None`, *optional*, defaults to `None`):
            The device of the metric value tensor.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The data type of the metric value tensor.

    Returns:
        `torch.Tensor`:
            The metric value of the channels.
    """
    name = metric_mode.name
    if name.endswith("Product"):
        # A "...Product" metric is the elementwise product of the matching
        # inputs metric and weights metric, computed recursively.
        base = name[: -len("Product")]
        ipts_metric = get_channel_metric(
            inputs=inputs,
            weights=weights,
            metric_mode=ChannelOrderCalibConfig.ChannelMetric[f"Inputs{base}"],
            num_channels=num_channels,
            num_heads=num_heads,
            device=device,
            dtype=dtype,
        )
        wgts_metric = get_channel_metric(
            inputs=inputs,
            weights=weights,
            metric_mode=ChannelOrderCalibConfig.ChannelMetric[f"Weights{base}"],
            num_channels=num_channels,
            num_heads=num_heads,
            device=device,
            dtype=dtype,
        )
        return ipts_metric * wgts_metric
    if name.startswith("Inputs"):
        assert inputs.num_tensors == 1, f"Only one input source is allowed, got {inputs.num_tensors}"
        name = name[len("Inputs") :]
        tensors = inputs.front().get_standardized_data(reshape=False)
    else:
        assert name.startswith("Weights")
        name = name[len("Weights") :]
        tensors = weights
    group_shape = [-1] * tensors[0].ndim
    group_shape[1] = num_channels // num_heads
    # Convert the CamelCase metric name into the snake_case ChannelMetric method name.
    pieces = []
    for ch in name:
        if ch.isupper():
            pieces.append("_")
            pieces.append(ch.lower())
        else:
            pieces.append(ch)
    metric_fn = getattr(ChannelMetric, "".join(pieces).lstrip("_"))
    return metric_fn(tensors, num_channels, group_shape, device=device, dtype=dtype).view(num_channels)
def update_channel_metric(
    metric: torch.Tensor | None,
    inputs: TensorsCache,
    weights: tp.Sequence[torch.Tensor],
    metric_mode: ChannelOrderCalibConfig.ChannelMetric,
    num_channels: int,
    num_heads: int = 1,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Update the metric value of the channels.

    Args:
        metric (`torch.Tensor` or `None`):
            The metric value of the channels.
        inputs (`TensorsCache`):
            The input activations.
        weights (`Sequence[torch.Tensor]`):
            The weight tensors.
        metric_mode (`ChannelOrderCalibConfig.ChannelMetric`):
            The channel metric mode.
        num_channels (`int`):
            The number of channels.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        device (`torch.device` or `str` or `None`, *optional*, defaults to `None`):
            The device of the metric value tensor.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The data type of the metric value tensor.

    Returns:
        `torch.Tensor`:
            The updated metric value of the channels.
    """
    # Fixed annotation: `device` defaults to None, so its type must include `None`
    # (now consistent with `get_channel_metric`).
    _metric = get_channel_metric(
        inputs=inputs,
        weights=weights,
        metric_mode=metric_mode,
        num_channels=num_channels,
        num_heads=num_heads,
        device=device,
        dtype=dtype,
    )
    if metric is None:
        return _metric
    elif "Max" in metric_mode.name:
        # "Max"-style metrics accumulate across batches via elementwise maximum.
        return torch.maximum(metric, _metric)
    else:
        # All other metrics accumulate via in-place summation.
        return metric.add_(_metric)
def init_channel_index_from_metric(
    metric: torch.Tensor,
    /,
    metric_mode: ChannelOrderCalibConfig.ChannelMetric,
    index_mode: ChannelOrderCalibConfig.ChannelIndex,
    group_size: int,
    num_heads: int = 1,
    num_head_repeats: int = 1,
) -> torch.Tensor:
    """Get the index of the channels.

    Args:
        metric (`torch.Tensor`):
            The metric value of the channels.
        metric_mode (`ChannelOrderCalibConfig.ChannelMetric`):
            The channel metric mode.
        index_mode (`ChannelOrderCalibConfig.ChannelIndex`):
            The index mode.
        group_size (`int`):
            The quantization group size.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        num_head_repeats (`int`, *optional*, defaults to `1`):
            The number of head repeats.

    Returns:
        `torch.Tensor`:
            The index of the channels.
    """
    num_channels = metric.numel()
    num_groups = num_channels // group_size
    if num_heads <= 1:
        # Single-head: rank every channel at once.
        return get_channel_index_from_rank(
            metric.argsort(), num_channels=num_channels, num_groups=num_groups, index_mode=index_mode
        )
    head_channels = num_channels // num_heads
    if num_head_repeats > 1:
        # Repeated heads must share one ordering: reduce the metric across the
        # repeats (max for "Max"-style metrics, sum otherwise) before ranking.
        num_unique_heads = num_heads // num_head_repeats
        metric = metric.view(num_unique_heads, num_head_repeats, head_channels)
        if "Max" in metric_mode.name:
            metric = metric.amax(dim=1, keepdim=True)
        else:
            metric = metric.sum(dim=1, keepdim=True)
        rank = metric.argsort(dim=-1).expand(num_unique_heads, num_head_repeats, -1).reshape(num_heads, -1)
    else:
        rank = metric.view(num_heads, head_channels).argsort(dim=-1)
    # Offset per-head ranks so they address global channel positions.
    rank += torch.arange(0, num_channels, head_channels, dtype=torch.long, device=rank.device).view(num_heads, 1)
    index = torch.empty_like(rank)
    for head in range(num_heads):
        index[head] = get_channel_index_from_rank(
            rank[head],
            num_channels=head_channels,
            num_groups=max(num_groups // num_heads, 1),
            index_mode=index_mode,
        )
    return index.view(-1)
class ChannelOrderCalibrator(SearchBasedCalibrator[ChannelOrderCalibConfig, torch.Tensor]):
    """The calibrator for quantization channel reordering."""

    def __init__(
        self,
        config: ChannelOrderCalibConfig,
        weight_quantizer: Quantizer | None,
        input_quantizer: Quantizer | None,
        num_heads: int = 1,
        num_head_repeats: int = 1,
        develop_dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the calibrator.

        Args:
            config (`ChannelOrderCalibConfig`):
                The channel order calibration configuration.
            weight_quantizer (`Quantizer` or `None`):
                The quantizer for the weights.
            input_quantizer (`Quantizer` or `None`):
                The quantizer for the inputs.
            num_heads (`int`, *optional*, defaults to `1`):
                The number of heads.
            num_head_repeats (`int`, *optional*, defaults to `1`):
                The number of head repeats.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The development data type.
        """
        super().__init__(
            tensor_type=TensorType.Weights,
            config=config,
            w_quantizer=weight_quantizer,
            x_quantizer=input_quantizer,
            y_quantizer=None,
            develop_dtype=develop_dtype,
        )
        assert self.config.objective == SearchBasedCalibObjective.OutputsError
        assert self.config.granularity == SearchBasedCalibGranularity.Layer
        if self.config.strategy == SearchBasedCalibStrategy.Manual:
            # Manual strategy evaluates exactly the configured (metric, index) pair.
            self.index_modes = [self.config.channel_index]
            self.metric_modes = [self.config.channel_metric]
        else:
            # Otherwise search the full cross product of metric and index modes.
            self.metric_modes = list(ChannelOrderCalibConfig.ChannelMetric.__members__.values())
            self.index_modes = list(ChannelOrderCalibConfig.ChannelIndex.__members__.values())
        self.num_index_modes, self.num_metric_modes = len(self.index_modes), len(self.metric_modes)
        self.num_heads = num_heads
        self.num_head_repeats = num_head_repeats
        self.metrics, self.channel_indexes = None, None

    @property
    def population_size(self) -> int:
        """Get the population size."""
        size = self.num_index_modes * self.num_metric_modes
        # The extra candidate is the identity-order (no reordering) baseline.
        return (size + 1) if self.config.strategy != SearchBasedCalibStrategy.Manual else size

    @property
    def allows_x_quant_for_wgts(self) -> bool:
        """Whether the calibrator needs activation quantization when tensor_type is Weights."""
        return self.config.allow_x_quant

    @property
    def allows_w_quant_for_wgts(self) -> bool:
        """Whether the calibrator needs weight quantization when tensor_type is Weights."""
        return self.config.allow_w_quant

    def update_channel_metrics(self, weights: list[torch.Tensor | nn.Parameter], inputs: TensorsCache) -> None:
        """Update the metrics of the channels.

        Args:
            weights (list[torch.Tensor | nn.Parameter]): The weight tensors.
            inputs (TensorsCache): The input activations.
        """
        weights = [w.data for w in weights]
        if self.metrics is None:
            # First call: infer channel count / device from the weights and start empty.
            self.num_channels = weights[0].shape[1]
            self.device = weights[0].device
            self.metrics = [None] * len(self.metric_modes)
        for metric_id, metric_mode in enumerate(self.metric_modes):
            self.metrics[metric_id] = update_channel_metric(
                metric=self.metrics[metric_id],
                inputs=inputs,
                weights=weights,
                metric_mode=metric_mode,
                num_channels=self.num_channels,
                num_heads=self.num_heads,
                device=self.device,
                dtype=self.develop_dtype,
            )

    def init_channel_indexes(self) -> None:
        """Initialize the indexes."""
        # The effective group size is the smallest of the input/weight quantization
        # group sizes (non-positive means "whole channel dimension").
        if self.needs_x_quant:
            ipts_group_size = self.x_quantizer.config.smallest_group_shape[1]
        else:
            ipts_group_size = -1
        if ipts_group_size <= 0:
            ipts_group_size = self.num_channels
        if self.needs_w_quant:
            wgts_group_size = self.w_quantizer.config.smallest_group_shape[1]
        else:
            wgts_group_size = -1
        if wgts_group_size <= 0:
            wgts_group_size = self.num_channels
        group_size = min(ipts_group_size, wgts_group_size)
        # Candidate 0 is the baseline (no reordering); the rest are ordered
        # metric-major / index-minor (index mode varies fastest).
        self.channel_indexes = [None] + [
            init_channel_index_from_metric(
                metric,
                metric_mode=metric_mode,
                index_mode=index_mode,
                group_size=group_size,
                num_heads=self.num_heads,
                num_head_repeats=self.num_head_repeats,
            )
            for metric_mode, metric in zip(self.metric_modes, self.metrics, strict=True)
            for index_mode in self.index_modes
        ]
        self.arange = torch.arange(self.num_channels, dtype=torch.long, device=self.device)
        # Metrics are no longer needed once the candidate indexes are built.
        self.metrics = None
        gc.collect()
        torch.cuda.empty_cache()

    def _reset(self, x_wgts: list[torch.Tensor | nn.Parameter], x_acts: TensorsCache, **kwargs) -> None:
        """Reset the calibrator.

        Args:
            x_wgts (list[torch.Tensor | nn.Parameter]): Weight tensors.
            x_acts (TensorsCache): Input activations.
        """
        if self.channel_indexes is None:
            self.update_channel_metrics(x_wgts, x_acts)
            self.init_channel_indexes()
        if self.config.strategy == SearchBasedCalibStrategy.Manual and self.channel_indexes[0] is None:
            # Manual strategy has no baseline candidate: drop the leading None.
            self.channel_indexes = self.channel_indexes[1:]
        assert len(self.channel_indexes) == self.population_size
        self.baseline_errors, self.best_error, self.best_candidate_id = None, None, None
        self.error_stats_history = []

    def get_best(self) -> torch.Tensor:
        """Get the best candidate.

        Returns:
            torch.Tensor: The best candidate.
        """
        return self.channel_indexes[self.best_candidate_id]

    def _ask(self) -> torch.Tensor:
        """Ask for the next candidate.

        Returns:
            torch.Tensor: The next candidate.
        """
        channel_index = self.channel_indexes[self.candidate_id]
        channel_index_inverse = None
        if channel_index is not None:
            # Build the inverse permutation so reordered tensors can be restored.
            channel_index_inverse = torch.zeros_like(channel_index)
            channel_index_inverse[channel_index] = self.arange.to(device=channel_index.device)
        self.candidate_inverse = channel_index_inverse
        return channel_index

    def _tell(self, errors: list[tuple[torch.Tensor, ...]]) -> None:  # noqa: C901
        """Tell the error of the last candidate and update the best candidate.

        Args:
            errors (list[tuple[torch.Tensor, ...]]): The error of the last candidate.
        """
        errors = [tuple(math.root_(e.to(torch.float64), self.config.degree) for e in error) for error in errors]
        if self.baseline_errors is None:
            self.baseline_errors = errors
        # error_stats = [#worse, -#better, total increase, total decrease, total error],
        # compared lexicographically so fewer regressions win first.
        error_stats = [0, 0, 0, 0, 0]
        for baseline_error, error in zip(self.baseline_errors, errors, strict=True):
            for be, e in zip(baseline_error, error, strict=True):
                _d = e.item() - be.item()
                if e > be:
                    error_stats[0] += 1
                if e < be:
                    error_stats[1] -= 1
                error_stats[2] += max(_d, 0)
                error_stats[3] += min(_d, 0)
                error_stats[4] += e.item()
        if self.best_error is None or error_stats < self.best_error:
            self.best_error = error_stats
            self.best_candidate_id = self.candidate_id
        if self.logger.level <= tools.logging.DEBUG:
            self.logger.debug(
                f"+ {self._get_metric_index_mode_str(self.candidate_id)} : {self._get_error_str(error_stats)}"
            )
        if self.is_last_candidate_in_iter():
            self.logger.debug(f"+ {self._get_metric_index_mode_str(self.best_candidate_id)} is the best candidate.")

    def _get_error_str(self, e: list[int | float]) -> str:
        """Format an error-stats vector for debug logging."""
        return f"[{e[0]:+d}, {e[1]:+d}, {e[2]:>10.4f}, {e[3]:>10.4f}, {e[4]:>10.4f}]"

    def _get_metric_index_mode_str(self, candidate_id: int) -> str:
        """Format the (metric mode, index mode) pair of a candidate for logging."""
        if candidate_id == 0:
            if self.config.strategy == SearchBasedCalibStrategy.Manual:
                metric_mode, index_mode = self.metric_modes[0], self.index_modes[0]
            else:
                return f"{'baseline':>20} {'':>10}"
        else:
            # Fixed: candidates are ordered metric-major / index-minor (see
            # init_channel_indexes), so num_index_modes is the fast-varying stride.
            # The original decoded with num_metric_modes, mislabeling candidates
            # whenever the two mode counts differ.
            metric_id = (candidate_id - 1) // self.num_index_modes
            index_id = (candidate_id - 1) % self.num_index_modes
            metric_mode, index_mode = self.metric_modes[metric_id], self.index_modes[index_id]
        return f"{metric_mode.name:>20} - {index_mode.name:>10}"

    def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        """Quantize inputs under the candidate order: reorder, quantize, restore."""
        if not self.needs_x_quant_for_wgts:
            return x
        if channels_dim is MISSING:
            channels_dim = self.x_quantizer.channels_dim
        if self.candidate is not None:
            x = x.index_select(dim=channels_dim, index=self.candidate.to(x.device))
        x = self.x_quantizer.quantize(x, channels_dim=channels_dim).data
        if self.candidate is not None:
            x = x.index_select(dim=channels_dim, index=self.candidate_inverse.to(x.device))
        return x

    def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
        """Quantize weights under the candidate order: reorder, quantize, restore."""
        if not self.needs_w_quant_for_wgts:
            return w
        if self.candidate is not None:
            w = w.index_select(dim=1, index=self.candidate.to(w.device))
        w = self.w_quantizer.quantize(w.data, kernel=None, develop_dtype=self.develop_dtype).data
        if self.candidate is not None:
            w = w.index_select(dim=1, index=self.candidate_inverse.to(w.device))
        return w

    def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int) -> torch.Tensor:
        """Unsupported: outputs-centric processing is never used by this calibrator."""
        raise RuntimeError("_process_x_in_yx should not be called in ChannelOrderCalibrator.")

    def _process_y_in_yx(self, x: torch.Tensor, channels_dim: int) -> torch.Tensor:
        """Unsupported: outputs-centric processing is never used by this calibrator."""
        raise RuntimeError("_process_y_in_yx should not be called in ChannelOrderCalibrator.")

    def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        """Unsupported: outputs-centric processing is never used by this calibrator."""
        raise RuntimeError("_process_xw_in_yx should not be called in ChannelOrderCalibrator.")

    def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        """Unsupported: outputs-centric processing is never used by this calibrator."""
        raise RuntimeError("_process_yw_in_yx should not be called in ChannelOrderCalibrator.")

    def _process_wgts_centric_mod(
        self,
        wgts: list[nn.Parameter],
        mods: list[nn.Module],
        *,
        reorder_wgts: list[tuple[nn.Parameter, int]],
        reorder_ipt_mods: list[tuple[nn.Module, int, BaseInputPackager | None]],
        reorder_opt_mods: list[tuple[nn.Module, int, BaseOutputPackager | None]],
        update_state_dict: bool = True,
        **kwargs,
    ) -> None:
        """Apply the candidate channel order to the weights and register reordering hooks before evaluation."""
        channels_index = self.candidate
        if update_state_dict:
            # Snapshot the weights so _recover_mod can restore them afterwards.
            self._state_dict.extend([(w, w.data) for w, _ in reorder_wgts])
        if channels_index is not None:
            for w, d in reorder_wgts:
                w.data = w.data.index_select(dim=d, index=channels_index.to(w.device))
            for m, channels_dim, packager in reorder_ipt_mods:
                self._hooks.append(
                    ChannelReorderer(channels_index, channels_dim, input_packager=packager).as_hook().register(m)
                )
            for m, channels_dim, packager in reorder_opt_mods:
                self._hooks.append(
                    ChannelReorderer(channels_index, channels_dim, output_packager=packager)
                    .as_hook(is_output=True)
                    .register(m)
                )
        self._candidate_backup = channels_index
        self.candidate = None  # we have already reordered and thus do not need to reorder again in _process
        super()._process_wgts_centric_mod(wgts, mods, update_state_dict=False)

    def _recover_mod(self) -> None:
        """Restore the modules and re-activate the candidate after evaluation."""
        super()._recover_mod()
        self.candidate = self._candidate_backup
        self._candidate_backup = None
================================================
FILE: deepcompressor/calib/rotate.py
================================================
# -*- coding: utf-8 -*-
"""Rotation Quantization module."""
import typing as tp
import torch
import torch.nn as nn
from ..utils.hooks import BaseInputPackager, IOHook
from ..utils.math import HadamardMatrix, hardmard_transform, random_hadamard_matrix
__all__ = [
"rotate_in_channels",
"rotate_out_channels",
"hadamard_in_channels",
"get_rotation_matrix",
"transform_rms_norm_and_linear",
"transform_layer_norm_to_rms_norm",
"transform_norm_and_linear",
]
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm)."""

    def __init__(self, hidden_size: int, eps=1e-6) -> None:
        """Initialize RMSNorm with a learnable per-channel weight and an epsilon for numerical stability."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` by their root mean square and scale by the learned weight."""
        orig_dtype = hidden_states.dtype
        # Compute statistics in float32 for numerical stability.
        states = hidden_states.to(torch.float32)
        variance = states.pow(2).mean(-1, keepdim=True)
        states = states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * states.to(orig_dtype)
class HadamardTransformHook(IOHook):
    """Pre-forward hook that applies a Hadamard transform to a module's inputs."""

    def __init__(
        self, rhs: torch.Tensor, lhs: torch.Tensor, lhs_k: int, scaled: bool = True, packager: BaseInputPackager = None
    ):
        """Store the Hadamard factors applied to every incoming tensor."""
        super().__init__(pre=True, post=False, input_packager=packager, output_packager=None)
        self.rhs = rhs
        self.lhs = lhs
        self.lhs_k = lhs_k
        self.scaled = scaled

    def pre_forward(
        self,
        module: nn.Module,
        input_args: tuple[torch.Tensor, ...],
        input_kwargs: dict[str, tp.Any],
    ) -> tuple[tuple[torch.Tensor, ...], dict[str, tp.Any]]:
        """Transform every unpacked input tensor in place, then repack the inputs."""
        tensors = self.input_packager.unpack(module, input_args, input_kwargs)
        for key in tensors:
            tensors[key] = hardmard_transform(
                tensors[key], hadamard_rhs=self.rhs, hadamard_lhs=self.lhs, lhs_k=self.lhs_k, scaled=self.scaled
            )
        return self.input_packager.repack(tensors, module, input_args, input_kwargs)
def rotate_in_channels(weight: nn.Parameter, /, *, rotation: torch.Tensor) -> None:
    """Rotate the input channels of a weight matrix."""
    orig_shape, orig_dtype = weight.shape, weight.dtype
    # Multiply in float64 for accuracy, then cast back to the original dtype.
    w64 = weight.data.view(-1, rotation.shape[0]).to(dtype=torch.float64)
    rotated = torch.matmul(w64, rotation.to(weight.device))
    weight.data = rotated.to(dtype=orig_dtype).view(orig_shape)
def rotate_out_channels(weight: nn.Parameter, /, *, rotation: torch.Tensor, bias: nn.Parameter | None = None) -> None:
"""Rotate the output channels of a weight matrix."""
shape, dtype = weight.shape, weight.dtype
out_channels, head_channels = shape[0], rotation.shape[0]
num_heads = out_channels // head_channels
weight.data = (
torch.matmul(
rotation.T.to(weight.device), weight.data.view(num_heads, head_channels, -1).to(dtype=torch.float64)
)
.to(dtype=dtype)
.view(shape)
)
if bias is not None:
bias.data = (
torch.matmul(
rotation.T.to(weight.device), bias.data.view(num_heads, head_channels, -1).to(dtype=torch.float64)
)
.to(dtype=dtype)
.view(-1)
)
def hadamard_in_channels(
    modules: tp.Iterable[nn.Module],
    packager: BaseInputPackager = None,
    dtype: torch.dtype | None = None,
    device: torch.device | str | None = None,
):
    """Apply Hadamard quantization to the input channels of the modules."""
    for module in modules:
        if not isinstance(module, nn.Linear):
            raise NotImplementedError(f"Module {module} not supported!")
        in_channels = module.in_features
        # Default device/dtype come from the first module's weight and then stick.
        device = device or module.weight.device
        dtype = dtype or module.weight.dtype
        # Fold the transform into the weights in float64 for accuracy.
        rhs64, lhs64, k64 = HadamardMatrix.get(in_channels, scale=True, dtype=torch.float64)
        transformed = hardmard_transform(
            module.weight.data.to(torch.float64), rhs64.to(device), lhs64.to(device), k64, scaled=True
        )
        module.weight.data = transformed.to(device=device, dtype=module.weight.dtype)
        del rhs64, lhs64, k64
        # Register a hook so the matching transform is applied to the inputs at runtime.
        rhs, lhs, k = HadamardMatrix.get(in_channels, scale=True, dtype=dtype, device=device)
        HadamardTransformHook(rhs=rhs, lhs=lhs, lhs_k=k, packager=packager).register(module)
def get_rotation_matrix(num_channels: int, random: bool = True, compatible: bool = True) -> torch.Tensor:
    """Get a random rotation matrix for the given number of channels."""
    if random:
        return random_hadamard_matrix(num_channels)
    rhs, lhs, k = HadamardMatrix.get(num_channels, scale=False)
    rhs = rhs.to(dtype=torch.float64)
    if k == 1:
        q = rhs
    elif compatible:
        # This Kronecker ordering matches hadamard_transform's factorization.
        q = torch.kron(lhs.T.contiguous().to(dtype=torch.float64), rhs)
    else:
        q = torch.kron(rhs, lhs.to(dtype=torch.float64))
    # Normalize so the matrix is orthonormal.
    return q.mul_(1.0 / torch.tensor(num_channels, dtype=torch.float64).sqrt())
def transform_rms_norm_and_linear(norm: nn.LayerNorm | RMSNorm, next_modules: tp.Iterable[nn.Linear]) -> None:
    """Fuse the weight multiplication of rms norm into the next adjacent linear modules.

    Args:
        norm (`nn.LayerNorm` or `RMSNorm`):
            normalization module.
        next_modules (`Iterable[nn.Linear]`):
            modules after the normalization module.
    """
    ln_w = norm.weight.data.to(dtype=torch.float64)
    # Reset the norm's affine parameters to identity; their effect moves into the linears.
    norm.weight.data = torch.ones_like(norm.weight.data)
    ln_b = None
    if getattr(norm, "bias", None) is not None:
        ln_b = norm.bias.data.to(dtype=torch.float64)
        norm.bias = None
    for linear in next_modules:
        assert isinstance(linear, nn.Linear)
        dtype = linear.weight.dtype
        fc_w = linear.weight.data.to(dtype=torch.float64)
        ln_w = ln_w.to(fc_w.device)
        # y = W (w ⊙ x + b) = (W diag(w)) x + W b
        linear.weight.data = (fc_w * ln_w).to(dtype=dtype)
        if ln_b is not None:
            ln_b = ln_b.to(fc_w.device)
            if linear.bias is None:
                linear.bias = nn.Parameter(torch.zeros(linear.out_features, dtype=dtype, device=linear.weight.device))
            linear.bias.data = (linear.bias.data.to(dtype=torch.float64) + torch.matmul(fc_w, ln_b)).to(dtype=dtype)
def transform_layer_norm_to_rms_norm(
    parent: nn.Module,
    norm_name: str,
    prev_modules: tp.Iterable[nn.Linear],
    prev_out_channels_dims: int | tp.Iterable[int] = 0,
) -> None:
    """Transform LayerNorm to RMSNorm.

    Args:
        parent (`nn.Module`):
            Parent module that contains the normalization module.
        norm_name (`str`):
            Name of the normalization module in `parent`.
        prev_modules (`Iterable[nn.Linear]`):
            Previous adjacent linear modules.
        prev_out_channels_dims (`int` or `Iterable[int]`, *optional*, defaults to `0`):
            Output channels dimension of the previous modules' weights.
    """
    # Resolve a dotted name down to the direct parent and the leaf attribute name.
    if "." in norm_name:
        *path, norm_name = norm_name.split(".")
        for attr in path:
            parent = getattr(parent, attr)
    norm = getattr(parent, norm_name)
    assert isinstance(norm, nn.LayerNorm)
    assert len(norm.normalized_shape) == 1, f"LayerNorm's #dims must be 1, got {len(norm.normalized_shape)}"
    assert norm.bias is None, "LayerNorm's bias must be None, please call `transform_rms_norm_and_linear` in advance"
    # region move substract mean to the previous linear modules
    assert len(prev_modules) > 0, "No previous modules found"
    if isinstance(prev_out_channels_dims, int):
        prev_out_channels_dims = [prev_out_channels_dims] * len(prev_modules)
    for module, dim in zip(prev_modules, prev_out_channels_dims, strict=True):
        if isinstance(module, nn.LayerNorm):
            module.bias = None
            continue
        if isinstance(module, nn.Linear):
            assert dim == 0, "Linear module's output channels dimension is 0"
        elif isinstance(module, nn.Embedding):
            assert dim == 1, "Embedding module's output channels dimension is 1"
        dtype = module.weight.dtype
        w = module.weight.data.to(dtype=torch.float64)
        # Centering the output channels makes the module's output zero-mean, so
        # LayerNorm's mean subtraction becomes redundant and RMSNorm suffices.
        module.weight.data = w.sub_(w.mean(dim=dim, keepdim=True)).to(dtype=dtype)
        if getattr(module, "bias", None) is not None:
            b = module.bias.data.to(dtype=torch.float64)
            module.bias.data = b.sub_(b.mean()).to(dtype=dtype)
    # endregion
    # region replace LayerNorm with RMSNorm
    rms = RMSNorm(hidden_size=norm.normalized_shape[0], eps=norm.eps)
    rms.weight.data = norm.weight.data
    setattr(parent, norm_name, rms)
    # endregion
# endregion
def transform_norm_and_linear(
    parent: nn.Module,
    norm_name: str,
    next_modules: tp.Iterable[nn.Linear],
    prev_modules: tp.Iterable[nn.Linear] | None = None,
    prev_out_channels_dims: int | tp.Iterable[int] = 0,
):
    """Transform the normalization module and the next adjacent linear modules.

    Args:
        parent (nn.Module): Parent module.
        norm_name (str): Name of the normalization module.
        next_modules (tp.Iterable[nn.Linear]): Next adjacent linear modules.
        prev_modules (tp.Iterable[nn.Linear]): Previous adjacent linear modules.
        prev_out_channels_dims (int | tp.Iterable[int], optional): Output channels dimension of the previous modules.
            Defaults to ``0``.
    """
    # Resolve a dotted name down to the direct parent and the leaf attribute name.
    if "." in norm_name:
        *path, norm_name = norm_name.split(".")
        for attr in path:
            parent = getattr(parent, attr)
    norm = getattr(parent, norm_name)
    # First fuse the norm's affine parameters into the following linear layers.
    transform_rms_norm_and_linear(norm, next_modules)
    if isinstance(norm, nn.LayerNorm):
        # LayerNorm additionally subtracts the mean; fold that into the previous
        # modules and swap the LayerNorm for an RMSNorm.
        transform_layer_norm_to_rms_norm(parent, norm_name, prev_modules, prev_out_channels_dims)
================================================
FILE: deepcompressor/calib/search.py
================================================
# -*- coding: utf-8 -*-
"""Search-based uantization calibrator module."""
import gc
import typing as tp
from abc import ABC, abstractmethod
from dataclasses import _MISSING_TYPE, MISSING
import psutil
import torch
import torch.nn as nn
import torch.utils.hooks
from ..data.cache import TensorCache, TensorsCache
from ..data.common import TensorType
from ..data.utils.reshape import ReshapeFn
from ..data.utils.shape import infer_view_shape
from ..quantizer.processor import Quantizer
from ..utils import tools
from ..utils.hooks import Hook
from .config import SearchBasedCalibConfig, SearchBasedCalibGranularity, SearchBasedCalibObjective
__all__ = ["SearchBasedCalibrator"]
def _reshape_w_for_wgts(w: torch.Tensor, w_view_shape: torch.Size) -> torch.Tensor:
# (#g0, gs0, #g1, gs1, ...)
w = w.view(w_view_shape)
# (#g0, gs0, #g1, gs1, ...) -> (#g0, ..., gs1, ..., gs0)
w = w.permute(*range(0, len(w_view_shape), 2), *range(3, len(w_view_shape), 2), 1)
# (#g0, ..., gs0, gs1, ...) -> (#g0, ..., gs1 * gs2 * ..., gs0)
return w.reshape(*w_view_shape[::2], -1, w_view_shape[1])
def _reshape_x_for_wgts(x: torch.Tensor, w_view_shape: torch.Size) -> torch.Tensor:
# x is unfolded already
num_samples = x.shape[0]
# (1, n, #g1, gs1, ...)
x = x.view(1, num_samples, *w_view_shape[2:])
# (1, n, #g1, gs1, ...) -> (1, #g1, ..., n, gs1, ...)
x = x.permute(*range(0, len(w_view_shape), 2), *range(1, len(w_view_shape), 2))
return x.reshape(1, *w_view_shape[2::2], num_samples, -1)
def _reshape_x_for_ipts(x: torch.Tensor, x_view_shape: torch.Size) -> torch.Tensor:
# x is original tensor without unfolding
# (#g0, gs0, #g1, gs1, ...)
x = x.view(x_view_shape)
# (#g0, gs0, #g1, gs1, ...) -> (#g0, #g1, ..., gs0, gs2, ..., gs1)
x = x.permute(*range(0, len(x_view_shape), 2), 1, *range(5, len(x_view_shape), 2), 3)
# (#g0, #g1, ..., gs0, gs2, ..., gs1) -> (#g0, #g1, ..., gs0 * gs2 * ..., gs1)
return x.reshape(*x_view_shape[::2], -1, x_view_shape[3])
def _reshape_w_for_ipts(w: torch.Tensor, x_view_shape: torch.Size) -> torch.Tensor:
return w.transpose(0, 1).reshape(1, x_view_shape[2], *([1] * (w.ndim - 2)), x_view_shape[3], -1)
_CANDIDATE = tp.TypeVar("_CANDIDATE")
_CONFIG = tp.TypeVar("_CONFIG", bound=SearchBasedCalibConfig)
class SearchBasedCalibrator(ABC, tp.Generic[_CONFIG, _CANDIDATE]):
    """The base class for search-based calibration."""

    # The calibration configuration driving the search.
    config: _CONFIG
    # The candidate currently under evaluation (set by `ask`).
    candidate: _CANDIDATE
def __init__(
self,
tensor_type: TensorType,
config: _CONFIG,
w_quantizer: Quantizer | None,
x_quantizer: Quantizer | None,
y_quantizer: Quantizer | None,
develop_dtype: torch.dtype,
) -> None:
"""Initialize the search-based calibrator.
Args:
tensor_type (`TensorType`):
The tensor type.
config (`_CONFIG`):
The calibration configuration.
w_quantizer (`Quantizer` or `None`):
The w quantizer for x-w computation.
x_quantizer (`Quantizer` or `None`):
The x quantizer for x-w or y-x computation.
y_quantizer (`Quantizer` or `None`):
The y quantizer for y-x computation.
develop_dtype (`torch.dtype`):
The development data type.
"""
self.tensor_type = tensor_type
self.config = config
self.objective = self.config.objective
self.granularity = self.config.granularity
self.opts_device = None
self.develop_dtype = develop_dtype
self.w_quantizer = w_quantizer
self.x_quantizer = x_quantizer
self.y_quantizer = y_quantizer
self.needs_w_quant = self.w_quantizer is not None and self.w_quantizer.is_enabled()
self.needs_x_quant = self.x_quantizer is not None and self.x_quantizer.is_enabled()
self.needs_y_quant = self.y_quantizer is not None and self.y_quantizer.is_enabled()
self.needs_x_quant_for_wgts = self.allows_x_quant_for_wgts and self.needs_x_quant
self.needs_w_quant_for_wgts = self.allows_w_quant_for_wgts and self.needs_w_quant
self.needs_x_quant_for_ipts = self.allows_x_quant_for_ipts and self.needs_x_quant
self.needs_w_quant_for_ipts = self.allows_w_quant_for_ipts and self.needs_w_quant
self.needs_x_quant_for_opts = self.allows_x_quant_for_opts and self.needs_x_quant
self.needs_y_quant_for_opts = self.allows_y_quant_for_opts and self.needs_y_quant
self.needs_w_quant_for_opts = self.allows_w_quant_for_opts and self.needs_w_quant
if self.tensor_type == TensorType.Weights:
self.quantizer = self.w_quantizer
self.needs_quant = self.needs_w_quant
elif self.tensor_type == TensorType.Inputs:
self.quantizer = self.x_quantizer
self.needs_quant = self.needs_x_quant
elif self.tensor_type == TensorType.Outputs:
self.quantizer = self.y_quantizer
self.needs_quant = self.needs_y_quant
else:
raise ValueError(f"unknown tensor type: {self.tensor_type}")
self.num_iters = getattr(self.config, "num_iters", 1)
self.logger = tools.logging.getLogger(f"{__name__}.{self.__class__.__name__.replace('Agent', '')}")
    @property
    @abstractmethod
    def population_size(self) -> int:
        """Get the population size (number of candidates evaluated per iteration)."""
        ...
    # Capability flags: subclasses override these to control which quantizers
    # may be applied while calibrating a given tensor type.
    @property
    def allows_x_quant_for_wgts(self) -> bool:
        """Whether the calibrator allows input quantization when tensor_type is Weights."""
        return False
    @property
    def allows_w_quant_for_wgts(self) -> bool:
        """Whether the calibrator allows weight quantization when tensor_type is Weights."""
        return True
    @property
    def allows_x_quant_for_ipts(self) -> bool:
        """Whether the calibrator allows input quantization when tensor_type is Inputs."""
        return True
    @property
    def allows_w_quant_for_ipts(self) -> bool:
        """Whether the calibrator allows weight quantization when tensor_type is Inputs."""
        return False
    @property
    def allows_x_quant_for_opts(self) -> bool:
        """Whether the calibrator allows x quantization when tensor_type is Outputs."""
        return True
    @property
    def allows_y_quant_for_opts(self) -> bool:
        """Whether the calibrator allows y quantization when tensor_type is Outputs."""
        return True
    @property
    def allows_w_quant_for_opts(self) -> bool:
        """Whether the calibrator allows weight quantization when tensor_type is Outputs."""
        return False
    @property
    def needs_to_pre_reshape_x_for_wgts(self) -> bool:
        """Whether the calibrator needs to pre-reshape the inputs for weight quantization calibration."""
        # Pre-reshaping is only done when inputs are not quantized per candidate
        # (a quantized x must be processed in its original layout each time).
        return not self.needs_x_quant_for_wgts and self.config.pre_reshape
    @property
    def needs_to_pre_reshape_w_for_ipts(self) -> bool:
        """Whether the calibrator needs to pre-reshape the weights for input quantization calibration."""
        return not self.needs_w_quant_for_ipts and self.config.pre_reshape
    def _reset(self, **kwargs) -> None:
        # Subclass hook: extra state initialization performed by `reset`.
        # Default is a no-op.
        pass
    def reset(self, **kwargs) -> None:
        """Reset the calibrator to the start of the search."""
        self.iter = 0
        self.candidate_id = 0
        self._reset(**kwargs)
        # (parameter, saved data) pairs to restore after evaluating a candidate.
        self._state_dict: list[tuple[nn.Parameter, torch.Tensor]] = []
        # Hooks registered while evaluating a candidate; removed by `_recover_mod`.
        self._hooks: list[Hook | torch.utils.hooks.RemovableHandle] = []
    def is_done(self) -> bool:
        """Check if the calibration is done."""
        return self.iter >= self.num_iters
    def is_last_iter(self) -> bool:
        """Check if the current iteration is the last one."""
        return self.iter == self.num_iters - 1
    def is_last_candidate_in_iter(self) -> bool:
        """Check if the current candidate is the last one in the current iteration."""
        return self.candidate_id == self.population_size - 1
    @abstractmethod
    def get_best(self) -> _CANDIDATE:
        """Get the best candidate.
        Returns:
            `_CANDIDATE`:
                The best candidate.
        """
        ...
    @abstractmethod
    def _ask(self) -> _CANDIDATE:
        """Ask for the next candidate.
        Returns:
            `_CANDIDATE`:
                The next candidate.
        """
        ...
    @abstractmethod
    def _tell(self, error: list[torch.Tensor]) -> None:
        """Tell the error of the last candidate and update the best candidate.
        Args:
            error (`list[torch.Tensor]`):
                The error of the last candidate.
        """
        ...
    def ask(self) -> _CANDIDATE:
        """Ask for the next candidate.
        Returns:
            `_CANDIDATE`:
                The next candidate.
        """
        # Cache the candidate so evaluation code can read `self.candidate`.
        self.candidate = self._ask()
        return self.candidate
    def tell(self, error: list[torch.Tensor]) -> None:
        """Tell the error of the last candidate and update the best candidate.
        Args:
            error (`list[torch.Tensor]`):
                The error of the last candidate.
        """
        self._tell(error)
        # Advance within the population; roll over to the next iteration once
        # the whole population has been evaluated.
        self.candidate_id += 1
        if self.candidate_id >= self.population_size:
            self.iter += 1
            self.candidate_id = 0
def _parse_ipts(self, ipts: TensorsCache | None, set_device: bool = False) -> TensorsCache | None:
if set_device:
self.opts_device = None
elif ipts is None:
return None
if self.objective == SearchBasedCalibObjective.ProductsError:
batch_size = self.config.element_batch_size
calib_size = self.config.element_size
elif self.objective == SearchBasedCalibObjective.OutputsError:
batch_size = self.config.sample_batch_size
calib_size = self.config.sample_size
else:
assert self.objective == SearchBasedCalibObjective.TensorError
batch_size = -1
calib_size = -1
prev_size = len(ipts.front().data)
parsed_ipts = TensorsCache(
{
key: ipt.repartition(
max_batch_size=batch_size,
max_size=calib_size,
standardize=self.objective == SearchBasedCalibObjective.ProductsError,
reshape=self.tensor_type == TensorType.Weights,
)
for key, ipt in ipts.items()
}
)
curr_size = len(parsed_ipts.front().data)
assert all(len(ipt.data) == curr_size for ipt in parsed_ipts.values())
if set_device and prev_size != curr_size:
self.opts_device = self.config.outputs_device
return parsed_ipts
def _parse_args( # noqa: C901
self,
x_wgts: list[nn.Parameter] | None,
y_wgts: list[nn.Parameter] | None,
x_acts: TensorsCache | None,
y_acts: TensorsCache | None,
eval_inputs: TensorsCache | None,
eval_module: nn.Module | None,
x_mods: list[nn.Module] | None,
y_mods: list[nn.Module] | None,
orig_x_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
orig_y_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
orig_x_acts: TensorsCache | None,
orig_y_acts: TensorsCache | None,
orig_eval_inputs: TensorsCache | None,
) -> tuple[
list[torch.Tensor | nn.Parameter] | None, # x_wgts
list[torch.Tensor | nn.Parameter] | None, # y_wgts
TensorsCache | None, # x_acts
TensorsCache | None, # y_acts
TensorsCache | None, # eval_inputs
nn.Module | None, # eval_module
list[nn.Module] | None, # x_mods
list[nn.Module] | None, # y_mods
list[tuple[nn.Parameter, torch.Tensor]] | None, # orig_x_wgts
list[tuple[nn.Parameter, torch.Tensor]] | None, # orig_y_wgts
TensorCache | None, # orig_x_acts
TensorCache | None, # orig_y_acts
TensorCache | None, # orig_eval_inputs
]:
# region Check the types of the arguments
if x_wgts is not None:
assert isinstance(x_wgts, (tuple, list)), "x_wgts should be a list"
assert all(isinstance(w, nn.Parameter) for w in x_wgts), "wgts should be a list of nn.Parameter"
if y_wgts is not None:
assert isinstance(y_wgts, (tuple, list)), "y_wgts should be a list"
assert all(isinstance(w, nn.Parameter) for w in y_wgts), "wgts should be a list of nn.Parameter"
if x_acts is not None:
assert isinstance(x_acts, TensorsCache), "x_acts should be a TensorsCache"
if y_acts is not None:
assert isinstance(y_acts, TensorsCache), "y_acts should be a TensorsCache"
if eval_inputs is not None:
assert isinstance(eval_inputs, TensorsCache), "eval_inputs should be a TensorsCache"
if x_mods is not None:
assert isinstance(x_mods, (tuple, list)), "x_mods should be a list"
if y_mods is not None:
assert isinstance(y_mods, (tuple, list)), "y_mods should be a list"
if orig_x_wgts is not None:
assert isinstance(orig_x_wgts, (tuple, list)), "orig_x_wgts should be a list"
assert all(isinstance(p, nn.Parameter) and isinstance(w, torch.Tensor) for p, w in orig_x_wgts), (
"orig_x_wgts should be a list of tuples of nn.Parameter and torch.Tensor"
)
if x_wgts is not None:
assert len(orig_x_wgts) >= len(x_wgts), "orig_wgts should have at least as mtp.Any elements as wgts"
assert all(p is w for (p, _), w in zip(orig_x_wgts, x_wgts, strict=False)), (
"the parameters in orig_wgts should be in wgts in the same order"
)
if orig_y_wgts is not None:
assert isinstance(orig_y_wgts, (tuple, list)), "orig_y_wgts should be a list"
assert all(isinstance(p, nn.Parameter) and isinstance(w, torch.Tensor) for p, w in orig_y_wgts), (
"orig_y_wgts should be a list of tuples of nn.Parameter and torch.Tensor"
)
if y_wgts is not None:
assert len(orig_y_wgts) >= len(y_wgts), "orig_wgts should have at least as mtp.Any elements as wgts"
assert all(p is w for (p, _), w in zip(orig_y_wgts, y_wgts, strict=False)), (
"the parameters in orig_wgts should be in wgts in the same order"
)
if orig_x_acts is not None:
assert isinstance(orig_x_acts, TensorsCache), "orig_x_acts should be a TensorsCache"
if orig_y_acts is not None:
assert isinstance(orig_y_acts, TensorsCache), "orig_y_acts should be a TensorsCache"
if orig_eval_inputs is not None:
assert isinstance(orig_eval_inputs, TensorsCache), "orig_eval_inputs should be a TensorsCache"
# endregion
self.objective = self.config.objective
self.granularity = self.config.granularity
if self.tensor_type == TensorType.Outputs:
# ! currently only support OutputsError and Layer granularity for Outputs
self.objective = SearchBasedCalibObjective.OutputsError
self.granularity = SearchBasedCalibGranularity.Layer
if self.objective == SearchBasedCalibObjective.TensorError:
if x_wgts is not None:
x_wgts = [w.detach().data for w in x_wgts]
if y_wgts is not None:
y_wgts = [w.detach().data for w in y_wgts]
if self.tensor_type == TensorType.Weights:
assert x_wgts is not None, "wgts should not be None when tensor_type is Weights"
elif self.tensor_type == TensorType.Inputs:
assert x_acts is not None, "mod_ipts should not be None when tensor_type is Inputs"
eval_inputs, orig_eval_inputs = x_acts, orig_x_acts
else: # self.tensor_type == TensorType.Outputs
assert y_acts is not None, "opts should not be None when tensor_type is Outputs"
eval_inputs, orig_eval_inputs = y_acts, orig_y_acts
eval_module = None
elif self.objective == SearchBasedCalibObjective.ProductsError:
assert self.tensor_type in (
TensorType.Weights,
TensorType.Inputs,
), "tensor_type should be Weights or Inputs when objective is ProductsError"
assert x_wgts is not None, "wgts should not be None when objective is ProductsError"
x_wgts = [w.detach().data for w in x_wgts]
if y_wgts is not None:
y_wgts = [w.detach().data for w in y_wgts]
x_acts = x_acts or eval_inputs
orig_x_acts = orig_x_acts or orig_eval_inputs
assert x_acts is not None, "x_acts should not be None when objective is ProductsError"
eval_inputs, orig_eval_inputs = x_acts, orig_x_acts
elif self.objective == SearchBasedCalibObjective.OutputsError:
assert eval_inputs is not None, "eval_inputs should not be None when objective is OutputsError"
assert eval_module is not None, "eval_module should not be None when OutputsError"
if (
isinstance(eval_module, (nn.Linear, nn.Conv2d))
and self.granularity.value < SearchBasedCalibGranularity.Layer.value
and self.tensor_type != TensorType.Outputs
):
self.objective = SearchBasedCalibObjective.ProductsError
x_wgts = [w.detach().data for w in x_wgts]
if y_wgts is not None:
y_wgts = [w.detach().data for w in y_wgts]
x_acts = x_acts or eval_inputs
orig_x_acts = orig_x_acts or orig_eval_inputs
assert x_acts is not None, "x_acts should not be None when objective is ProductsError"
eval_inputs, orig_eval_inputs = x_acts, orig_x_acts
else:
self.objective = SearchBasedCalibObjective.OutputsError
self.granularity = SearchBasedCalibGranularity.Layer
else:
raise ValueError(f"unknown objective: {self.objective}")
self.logger.debug(
f"+ tensor_type: {self.tensor_type}, objective: {self.objective}, granularity: {self.granularity}"
)
return (
x_wgts,
y_wgts,
x_acts,
y_acts,
self._parse_ipts(eval_inputs, set_device=True),
eval_module,
x_mods,
y_mods,
orig_x_wgts,
orig_y_wgts,
orig_x_acts,
orig_y_acts,
self._parse_ipts(orig_eval_inputs),
)
# region Reshape functions for computing products
    def _reshape_w_for_wgts_centric_partial_products(self, w: torch.Tensor, *, view_shape: torch.Size) -> torch.Tensor:
        """Reshape weights into grouped form for weight-centric partial products."""
        return _reshape_w_for_wgts(w, view_shape)
    def _reshape_x_for_wgts_centric_partial_products(
        self, x: torch.Tensor, *, view_shape: torch.Size, fn: ReshapeFn
    ) -> torch.Tensor:
        """Unfold inputs with ``fn``, then reshape to line up with grouped weights."""
        return _reshape_x_for_wgts(fn(x), view_shape)
    def _reshape_w_for_ipts_centric_partial_products(self, w: torch.Tensor, *, view_shape: torch.Size) -> torch.Tensor:
        """Reshape weights to broadcast against grouped inputs for input-centric partial products."""
        return _reshape_w_for_ipts(w, view_shape)
    def _reshape_x_for_ipts_centric_partial_products(
        self, x: torch.Tensor, *, view_shape: torch.Size, fn: ReshapeFn | None = None
    ) -> torch.Tensor:
        """Reshape inputs into grouped form; ``fn`` is accepted for interface parity but unused."""
        return _reshape_x_for_ipts(x, view_shape)
    def _reshape_w_for_full_products(self, w: torch.Tensor, *, view_shape: torch.Size | None = None) -> torch.Tensor:
        """Flatten weights to 2-D and transpose so they right-multiply flattened inputs; ``view_shape`` unused."""
        return w.view(w.shape[0], -1).T
    def _reshape_x_for_full_products(
        self, x: torch.Tensor, *, fn: ReshapeFn, view_shape: torch.Size | None = None
    ) -> torch.Tensor:
        """Unfold inputs with ``fn`` and flatten each sample to a row; ``view_shape`` unused."""
        return fn(x).view(x.shape[0], -1)
# endregion
@abstractmethod
def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor: ...
@abstractmethod
def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor: ...
@abstractmethod
def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor: ...
@abstractmethod
def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor: ...
    def _recover_mod(self) -> None:
        # Restore every parameter's saved data recorded in `_state_dict` ...
        for p, w in self._state_dict:
            p.data = w
        self._state_dict.clear()
        # ... and detach all hooks registered during candidate evaluation.
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()
    def _process_wgts_centric_mod(
        self, wgts: list[nn.Parameter], mods: list[nn.Module], update_state_dict: bool = True, **kwargs
    ) -> None:
        """Patch weights in-place and hook module inputs for weight-centric evaluation.

        Args:
            wgts (`list[nn.Parameter]`): Weights to process in-place via `_process_w_in_xw`.
            mods (`list[nn.Module]`): Modules whose inputs are processed via hooks.
            update_state_dict (`bool`, *optional*, defaults to `True`):
                Whether to save the original weight data so `_recover_mod` can restore it.
        """
        if self.needs_w_quant_for_wgts:
            for w in wgts:
                if update_state_dict:
                    self._state_dict.append((w, w.data))
                w.data = self._process_w_in_xw(w.data)
        if self.needs_x_quant_for_wgts:
            # is_output=False: the hook processes the module's *inputs*.
            self._hooks.append(self.x_quantizer.as_hook(func=self._process_x_in_xw, is_output=False).register(mods))
    def _process_ipts_centric_mod(
        self, wgts: list[nn.Parameter], mods: list[nn.Module], update_state_dict: bool = True, **kwargs
    ) -> None:
        """Patch weights in-place and hook module inputs for input-centric evaluation.

        Args:
            wgts (`list[nn.Parameter]`): Weights to process in-place via `_process_w_in_xw`.
            mods (`list[nn.Module]`): Modules whose inputs are processed via hooks.
            update_state_dict (`bool`, *optional*, defaults to `True`):
                Whether to save the original weight data so `_recover_mod` can restore it.
        """
        if self.needs_w_quant_for_ipts:
            for w in wgts:
                if update_state_dict:
                    self._state_dict.append((w, w.data))
                w.data = self._process_w_in_xw(w.data)
        if self.needs_x_quant_for_ipts:
            # is_output=False: the hook processes the module's *inputs*.
            self._hooks.append(self.x_quantizer.as_hook(self._process_x_in_xw, is_output=False).register(mods))
    def _process_opts_centric_mod(
        self,
        x_wgts: list[nn.Parameter],
        y_wgts: list[nn.Parameter],
        x_mods: list[nn.Module],
        y_mods: list[nn.Module],
        update_state_dict: bool = True,
        **kwargs,
    ) -> None:
        """Patch weights and hook module outputs for output-centric (y-x) evaluation.

        Args:
            x_wgts (`list[nn.Parameter]`): Weights of modules producing x.
            y_wgts (`list[nn.Parameter]`): Weights of modules producing y.
            x_mods (`list[nn.Module]`): Modules whose outputs are processed as x.
            y_mods (`list[nn.Module]`): Modules whose outputs are processed as y.
            update_state_dict (`bool`, *optional*, defaults to `True`):
                Whether to save the original weight data so `_recover_mod` can restore it.
        """
        if self.needs_w_quant_for_opts:
            for w in x_wgts:
                if update_state_dict:
                    self._state_dict.append((w, w.data))
                w.data = self._process_xw_in_yx(w.detach().data)
            for w in y_wgts:
                if update_state_dict:
                    self._state_dict.append((w, w.data))
                w.data = self._process_yw_in_yx(w.detach().data)
        if self.needs_x_quant_for_opts:
            # is_output=True: hooks here process the modules' *outputs*.
            self._hooks.append(self.x_quantizer.as_hook(self._process_x_in_yx, is_output=True).register(x_mods))
        if self.needs_y_quant_for_opts:
            self._hooks.append(self.y_quantizer.as_hook(self._process_y_in_yx, is_output=True).register(y_mods))
    def calibrate(
        self,
        x_wgts: list[nn.Parameter] | None = None,
        y_wgts: list[nn.Parameter] | None = None,
        x_acts: TensorsCache | None = None,
        y_acts: TensorsCache | None = None,
        x_mods: list[nn.Module] | None = None,
        y_mods: list[nn.Module] | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_module: nn.Module | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
        orig_x_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None = None,
        orig_y_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None = None,
        orig_x_acts: TensorsCache | None = None,
        orig_y_acts: TensorsCache | None = None,
        orig_eval_inputs: TensorsCache | None = None,
        **kwargs,
    ) -> _CANDIDATE:
        """Calibrate the quantization parameters.
        Args:
            x_wgts (`list[nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights in x-w computation, or weights that generates x for y-x computation.
            y_wgts (`list[nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights that generates y for y-x computation.
            x_acts (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The x activations. It should be x for x-w or y-x computation.
            y_acts (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The y activations. It should be y for y-x computation.
            eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The inputs of evaluation module `eval_module`.
            eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
                The module used for evaluation.
            x_mods (`list[nn.Module]` or `None`, *optional*, defaults to `None`):
                The modules for x activation quantization.
                It should be the modules that take in x for x-w computation,
                or the modules that generates x for y-x computation.
            y_mods (`list[nn.Module]` or `None`, *optional*, defaults to `None`):
                The modules for y activation quantization.
                It should be the modules that generates y for y-x computation.
            orig_x_wgts (`list[tuple[nn.Parameter, torch.Tensor]]` or `None`, *optional*, defaults to `None`):
                The original weights for `x_mods`.
            orig_y_wgts (`list[tuple[nn.Parameter, torch.Tensor]]` or `None`, *optional*, defaults to `None`):
                The original weights for `y_mods`.
            orig_x_acts (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original x activations `x_acts`.
            orig_y_acts (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original y activations `y_acts`.
            orig_eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The original inputs of evaluation module `eval_inputs`.
            eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
                The keyword arguments for evaluation module `eval_module`.
        Returns:
            `_CANDIDATE`:
                The best candidate.
        """
        tools.logging.Formatter.indent_inc()
        # Log the effective quantizer configurations (or None when disabled).
        if self.w_quantizer is not None and self.w_quantizer.is_enabled():
            self.logger.debug(f"+ w: {self.w_quantizer.config.quant_dtype}")
        else:
            self.logger.debug("+ w: None")
        if self.x_quantizer is not None and self.x_quantizer.is_enabled():
            self.logger.debug(f"+ x: {self.x_quantizer.config.quant_dtype}")
        else:
            self.logger.debug("+ x: None")
        if self.y_quantizer is not None and self.y_quantizer.is_enabled():
            self.logger.debug(f"+ y: {self.y_quantizer.config.quant_dtype}")
        else:
            self.logger.debug("+ y: None")
        # Normalize all arguments; this may also adjust `self.objective` and
        # `self.granularity` (see `_parse_args`).
        (
            x_wgts,
            y_wgts,
            x_acts,
            y_acts,
            eval_inputs,
            eval_module,
            x_mods,
            y_mods,
            orig_x_wgts,
            orig_y_wgts,
            orig_x_acts,
            orig_y_acts,
            orig_eval_inputs,
        ) = self._parse_args(
            x_wgts,
            y_wgts,
            x_acts,
            y_acts,
            eval_inputs,
            eval_module,
            x_mods,
            y_mods,
            orig_x_wgts,
            orig_y_wgts,
            orig_x_acts,
            orig_y_acts,
            orig_eval_inputs,
        )
        eval_kwargs = eval_kwargs or {}
        self.logger.debug(f"+ finished parsing calibration arguments, ram usage: {psutil.virtual_memory().percent}")
        # Reset the search state; subclasses receive all parsed arguments.
        self.reset(
            x_wgts=x_wgts,
            y_wgts=y_wgts,
            x_acts=x_acts,
            y_acts=y_acts,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            x_mods=x_mods,
            y_mods=y_mods,
            orig_x_wgts=orig_x_wgts,
            orig_y_wgts=orig_y_wgts,
            orig_x_acts=orig_x_acts,
            orig_y_acts=orig_y_acts,
            orig_eval_inputs=orig_eval_inputs,
            eval_kwargs=eval_kwargs,
            **kwargs,
        )
        self.logger.debug(f"+ finished resetting calibrator, ram usage: {psutil.virtual_memory().percent}")
        gc.collect()
        torch.cuda.empty_cache()
        # Dispatch to the tensor-type-specific search loop.
        if self.tensor_type == TensorType.Weights:
            result = self._calibrate_wgts(
                x_wgts, eval_inputs, eval_module, x_mods, orig_x_wgts, orig_eval_inputs, eval_kwargs, **kwargs
            )
        elif self.tensor_type == TensorType.Inputs:
            result = self._calibrate_ipts(
                x_wgts, eval_inputs, eval_module, x_mods, orig_x_wgts, orig_eval_inputs, eval_kwargs, **kwargs
            )
        else:
            result = self._calibrate_opts(
                x_wgts,
                y_wgts,
                eval_inputs,
                eval_module,
                x_mods,
                y_mods,
                orig_x_wgts,
                orig_y_wgts,
                orig_eval_inputs,
                eval_kwargs,
                **kwargs,
            )
        tools.logging.Formatter.indent_dec()
        return result
    def _calibrate_wgts(  # noqa: C901
        self,
        wgts: list[torch.Tensor | nn.Parameter],
        ipts: TensorsCache | None,
        eval_module: nn.Module | None,
        mods: list[nn.Module] | None,
        orig_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
        orig_ipts: TensorsCache | None,
        eval_kwargs: dict[str, tp.Any],
        **kwargs,
    ) -> tp.Any:
        """Run the search loop for weight quantization calibration.

        Step 1 computes the baseline (original weights / products / module
        outputs depending on the objective); step 2 iterates ask/tell over
        candidates, scoring each by its error against the baseline.

        Args:
            wgts: The weights being calibrated.
            ipts: Cached inputs (required for ProductsError/OutputsError).
            eval_module: Module evaluated under the OutputsError objective.
            mods: Modules hooked for input quantization during evaluation.
            orig_wgts: Original (parameter, data) pairs; defaults to `wgts` data.
            orig_ipts: Original cached inputs; defaults to `ipts`.
            eval_kwargs: Extra keyword arguments for `eval_module`.

        Returns:
            The best candidate found by the search.
        """
        # region Step 1: Calculate the baseline
        if self.objective == SearchBasedCalibObjective.TensorError:
            if orig_wgts is None:
                orig_wgts = [(None, w.detach().data) for w in wgts]
            # All weights must agree on every dim except the leading (output) one.
            assert all(w.shape[1:] == wgts[0].shape[1:] for w in wgts)
            assert all(w.shape[1:] == wgts[0].shape[1:] for _, w in orig_wgts)
            orig_opts = None
            w_view_shapes = [infer_view_shape(w.shape, self.w_quantizer.config.largest_group_shape) for w in wgts]
        elif self.objective == SearchBasedCalibObjective.ProductsError:
            if orig_wgts is None:
                orig_wgts = [(None, w.detach().data) for w in wgts]
            assert len(orig_wgts) == len(wgts)
            assert all(w.shape[1:] == wgts[0].shape[1:] for w in wgts)
            assert all(w.shape[1:] == wgts[0].shape[1:] for _, w in orig_wgts)
            w_view_shapes = [infer_view_shape(w.shape, self.w_quantizer.config.largest_group_shape) for w in wgts]
            if self.granularity != SearchBasedCalibGranularity.Layer:
                _reshape_x = self._reshape_x_for_wgts_centric_partial_products
                _reshape_w = self._reshape_w_for_wgts_centric_partial_products
            else:
                _reshape_x = self._reshape_x_for_full_products
                _reshape_w = self._reshape_w_for_full_products
            assert isinstance(ipts, TensorsCache), "ipts should not be None for ProductsError"
            if orig_ipts is None:
                orig_ipts = ipts
            same_ipts = orig_ipts is ipts
            # Pre-reshape the original inputs once. Using w_view_shapes[0] is
            # safe: the asserts above make shape[1:] identical across weights,
            # and _reshape_x only consumes view_shape[2:].
            orig_ipts = TensorsCache(
                {
                    key: TensorCache(
                        [_reshape_x(x, view_shape=w_view_shapes[0], fn=ipt.reshape) for x in ipt.data],
                        **ipt.get_factory_kwargs(channels_dim=1, reshape=ReshapeFn()),
                    )
                    for key, ipt in orig_ipts.items()
                },
            )
            # Baseline products, keyed by (batch index, input index, weight index).
            orig_opts: dict[tuple[int, ...], torch.Tensor] = {}
            for j, (_, w) in enumerate(orig_wgts):
                w = _reshape_w(w, view_shape=w_view_shapes[j])
                for s, ipt in enumerate(orig_ipts):
                    for i, x in enumerate(ipt.data):
                        x = x.to(device=w.device, non_blocking=True)
                        y = torch.matmul(x, w)
                        y = y.view(*y.shape[:-2], y.shape[-2] * y.shape[-1])
                        # Offload to opts_device when set (see _parse_ipts).
                        orig_opts[(i, s, j)] = y.to(device=self.opts_device or y.device, non_blocking=True)
            if self.needs_to_pre_reshape_x_for_wgts:
                # Inputs are not quantized per candidate, so reshape them once here.
                if same_ipts:
                    ipts = orig_ipts
                else:
                    ipts = TensorsCache(
                        {
                            key: TensorCache(
                                [_reshape_x(x, view_shape=w_view_shapes[0], fn=ipt.reshape) for x in ipt.data],
                                **ipt.get_factory_kwargs(channels_dim=1, reshape=ReshapeFn()),
                            )
                            for key, ipt in ipts.items()
                        }
                    )
            del orig_wgts, orig_ipts, same_ipts
        elif self.objective == SearchBasedCalibObjective.OutputsError:
            w_view_shapes, _state_dict = [], []
            if orig_wgts is not None:
                # Temporarily swap in the original weights to record baseline outputs.
                _state_dict = [(p, p.data) for p, _ in orig_wgts]
                for p, w in orig_wgts:
                    p.data = w.to(device=p.data.device)
            if orig_ipts is None:
                orig_ipts = ipts
            assert isinstance(orig_ipts, TensorsCache), "orig_ipts should not be None for OutputsError"
            orig_opts: dict[tuple[int, ...], torch.Tensor] = {}
            for i in range(len(orig_ipts.front().data)):
                ipt = orig_ipts.extract(i, eval_kwargs)
                y = eval_module(*ipt.args, **ipt.kwargs)
                # Modules may return a tuple; the first element is the output tensor.
                y = y[0] if not isinstance(y, torch.Tensor) else y
                assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
                orig_opts[(i,)] = y.to(device=self.opts_device or y.device, non_blocking=True)
                del ipt, y
            # Restore the (possibly quantized) weights swapped out above.
            for p, s in _state_dict:
                p.data = s
            del orig_wgts, orig_ipts, _state_dict
        else:
            raise ValueError(f"Unknown objective {self.objective}")
        gc.collect()
        torch.cuda.empty_cache()
        self.logger.debug(f"+ finished calculating the original outputs, ram usage: {psutil.virtual_memory().percent}")
        # endregion
        while not self.is_done():
            self.ask()
            e: list[torch.Tensor] = []
            # region Step 2: Calculate the errors
            if self.objective == SearchBasedCalibObjective.TensorError:
                assert isinstance(orig_wgts, (tuple, list))
                for w, (_, orig_w), w_view_shape in zip(wgts, orig_wgts, w_view_shapes, strict=True):
                    e_w = self._process_w_in_xw(w).sub_(orig_w)
                    if self.granularity == SearchBasedCalibGranularity.Group:
                        # Reduce |error|^degree over the group-size axes.
                        e_w = e_w.view(w_view_shape).abs_().pow_(self.config.degree)
                        e_w = e_w.sum(dim=tuple(range(1, len(w_view_shape), 2))).view(w_view_shape[::2])
                    elif self.granularity == SearchBasedCalibGranularity.ChannelGroup:
                        e_w = e_w.view(*w_view_shape[:4], -1).abs_().pow_(self.config.degree)
                        e_w = e_w.sum(dim=(0, 1, 3, 4)).view(w_view_shape[2])
                    elif self.granularity == SearchBasedCalibGranularity.Layer:
                        e_w = e_w.abs_().pow_(self.config.degree).sum().view(-1)
                    else:
                        raise ValueError(f"Unknown granularity {self.granularity}")
                    e.append(e_w)
            elif self.objective == SearchBasedCalibObjective.ProductsError:
                e = [None] * len(wgts)
                for j, w in enumerate(wgts):
                    w = _reshape_w(self._process_w_in_xw(w), view_shape=w_view_shapes[j])
                    for s, ipt in enumerate(ipts):
                        for i, x in enumerate(ipt.data):
                            x = x.to(device=w.device, non_blocking=True)
                            # If inputs were pre-reshaped, skip per-candidate processing.
                            if not self.needs_to_pre_reshape_x_for_wgts:
                                x = self._process_x_in_xw(x, channels_dim=ipt.channels_dim)
                                x = _reshape_x(x, view_shape=w_view_shapes[j], fn=ipt.reshape)
                            y = torch.matmul(x, w)
                            y = y.view(*y.shape[:-2], y.shape[-2] * y.shape[-1])
                            y = y.sub_(orig_opts[(i, s, j)].to(device=w.device, non_blocking=True))
                            # NOTE(review): unlike the TensorError path there is no
                            # abs_() before pow_ here — with an odd degree, signed
                            # errors could cancel; presumably degree is even — confirm.
                            if self.granularity == SearchBasedCalibGranularity.Group:
                                y = y.to(self.develop_dtype).pow_(self.config.degree).sum(dim=-1)
                            elif self.granularity == SearchBasedCalibGranularity.ChannelGroup:
                                y = y.view(y.shape[0], y.shape[1], -1)
                                y = y.to(self.develop_dtype).pow_(self.config.degree).sum(dim=(0, 2))
                            elif self.granularity == SearchBasedCalibGranularity.Layer:
                                y = y.to(self.develop_dtype).pow_(self.config.degree).sum().view(-1)
                            else:
                                raise ValueError(f"Unknown granularity {self.granularity}")
                            # Accumulate errors over inputs and batches for weight j.
                            if e[j] is None:
                                e[j] = y
                            else:
                                e[j].add_(y)
            elif self.objective == SearchBasedCalibObjective.OutputsError:
                # Patch weights/hooks for this candidate, evaluate, then restore.
                self._process_wgts_centric_mod(wgts=wgts, mods=mods, **kwargs)
                e = [None]
                for i in range(len(ipts.front().data)):
                    ipt = ipts.extract(i, eval_kwargs)
                    y = eval_module(*ipt.args, **ipt.kwargs)
                    y = y[0] if not isinstance(y, torch.Tensor) else y
                    assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
                    y = (y - orig_opts[(i,)].to(device=y.device, non_blocking=True)).to(self.develop_dtype)
                    y = y.pow_(self.config.degree).sum().view(-1)
                    if e[0] is None:
                        e[0] = y
                    else:
                        e[0].add_(y)
                    del ipt, y
                self._recover_mod()
            else:
                raise ValueError(f"Unknown objective {self.objective}")
            # endregion
            self.tell(e)
        return self.get_best()
def _calibrate_ipts( # noqa: C901
self,
wgts: list[torch.Tensor | nn.Parameter],
ipts: TensorsCache,
eval_module: nn.Module | None,
mods: list[nn.Module] | None,
orig_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
orig_ipts: TensorsCache | None,
eval_kwargs: dict[str, tp.Any],
**kwargs,
) -> tp.Any:
if orig_ipts is None:
orig_ipts = ipts
assert ipts.num_tensors == orig_ipts.num_tensors
assert all(
x.shape == orig_x.shape
for ipt, orig_ipt in zip(ipts, orig_ipts, strict=True)
for x, orig_x in zip(ipt.data, orig_ipt.data, strict=True)
)
# region Step 1: Calculate the outputs
if self.objective == SearchBasedCalibObjective.TensorError:
assert all(x.shape == ipt.data[0].shape for ipt in ipts for x in ipt.data)
orig_opts = None
x_view_shapes = [
infer_view_shape(
ipt.data[0].view(-1, *ipt.data[0].shape[ipt.channels_dim :]).shape,
self.x_quantizer.config.largest_group_shape,
skip_first_dim=True,
)
for ipt in ipts
]
del orig_wgts
elif self.objective == SearchBasedCalibObjective.ProductsError:
assert all(ipt.channels_dim == 1 for ipt in ipts)
assert all(ipt.channels_dim == 1 for ipt in orig_ipts)
assert all(x.shape[1:] == ipts.front().data[0].shape[1:] for ipt in ipts for x in ipt.data)
if orig_wgts is None:
orig_wgts = [(None, w.detach().data) for w in wgts]
assert len(orig_wgts) == len(wgts)
if self.granularity != SearchBasedCalibGranularity.Layer:
_reshape_x = self._reshape_x_for_ipts_centric_partial_products
_reshape_w = self._reshape_w_for_ipts_centric_partial_products
else:
_reshape_x = self._reshape_x_for_full_products
_reshape_w = self._reshape_w_for_full_products
x_view_shapes = [
infer_view_shape(ipt.data[0].shape, self.x_quantizer.config.largest_group_shape, skip_first_dim=True)
for ipt in ipts
]
orig_opts: dict[tuple[int, ...], torch.Tensor] = {}
for j, (_, w) in enumerate(orig_wgts):
w = _reshape_w(w, view_shape=x_view_shapes[0])
for s, ipt in enumerate(orig_ipts):
for i, x in enumerate(ipt.data):
x = x.to(device=w.device, non_blocking=True)
x = _reshape_x(x, view_shape=x_view_shapes[s], fn=ipt.reshape)
y = torch.matmul(x, w)
y = y.view(*y.shape[:-2], y.shape[-2] * y.shape[-1])
orig_opts[(i, s, j)] = y.to(device=self.opts_device or y.device, non_blocking=True)
if self.needs_to_pre_reshape_w_for_ipts:
for j, w in enumerate(wgts):
wgts[j] = _reshape_w(w, view_shape=x_view_shapes[0])
del orig_wgts, orig_ipts
elif self.objective == SearchBasedCalibObjective.OutputsError:
x_view_shapes, _state_dict = [], []
if orig_wgts is not None:
_state_dict = [(p, p.data) for p, _ in orig_wgts]
for p, w in orig_wgts:
p.data = w.to(device=p.data.device)
orig_opts: dict[tuple[int, ...], torch.Tensor] = {}
for i in range(len(orig_ipts.front().data)):
ipt = orig_ipts.extract(i, eval_kwargs)
y = eval_module(*ipt.args, **ipt.kwargs)
y = y[0] if not isinstance(y, torch.Tensor) else y
assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
orig_opts[(i,)] = y.to(device=self.opts_device or y.device, non_blocking=True)
del ipt, y
for p, s in _state_dict:
p.data = s
del orig_wgts, orig_ipts, _state_dict
else:
raise ValueError(f"Unknown objective {self.objective}")
gc.collect()
torch.cuda.empty_cache()
# endregion
while not self.is_done():
self.ask()
e: list[torch.Tensor] = []
# region Step 2: Calculate the outputs errors
if self.objective == SearchBasedCalibObjective.TensorError:
e = [None] * len(ipts)
for s, (ipt, x_view_shape) in enumerate(zip(ipts, x_view_shapes, strict=True)):
for x in ipt.data:
e_x = self._process_x_in_xw(x, channels_dim=ipt.channels_dim).sub_(x)
if self.granularity == SearchBasedCalibGranularity.Group:
e_x = e_x.view(x_view_shape).abs_().pow_(self.config.degree)
e_x = e_x.sum(dim=tuple(range(1, len(x_view_shape), 2)))
if self.granularity == SearchBasedCalibGranularity.ChannelGroup:
e_x = e_x.view(*x_view_shape[:4], -1).abs_().pow_(self.config.degree)
e_x = e_x.sum(dim=(0, 1, 3, 4)).view(x_view_shape[2])
elif self.granularity == SearchBasedCalibGranularity.Layer:
e_x = e_x.abs_().pow_(self.config.degree).sum().view(-1)
else:
raise ValueError(f"Unknown granularity {self.granularity}")
if e[s] is None:
e[s] = e_x
else:
e[s].add_(e_x)
elif self.objective == SearchBasedCalibObjective.ProductsError:
e = [None] * len(ipts)
for j, w in enumerate(wgts):
if not self.needs_to_pre_reshape_w_for_ipts:
w = self._process_w_in_xw(w)
w = _reshape_w(w, view_shape=x_view_shapes[0])
for s, ipt in enumerate(ipts):
for i, x in enumerate(ipt.data):
x = x.to(device=w.device, non_blocking=True)
x = self._process_x_in_xw(x, channels_dim=ipt.channels_dim)
x = _reshape_x(x, view_shape=x_view_shapes[s], fn=ipt.reshape)
y = torch.matmul(x, w)
y = y.view(*y.shape[:-2], y.shape[-2] * y.shape[-1])
y = y.sub_(orig_opts[(i, s, j)].to(device=w.device, non_blocking=True))
if self.granularity == SearchBasedCalibGranularity.Group:
y = y.to(self.develop_dtype).pow_(self.config.degree).sum(dim=-1)
elif self.granularity == SearchBasedCalibGranularity.ChannelGroup:
y = y.view(y.shape[0], y.shape[1], -1)
y = y.to(self.develop_dtype).pow_(self.config.degree).sum(dim=(0, 2))
elif self.granularity == SearchBasedCalibGranularity.Layer:
y = y.to(self.develop_dtype).pow_(self.config.degree).sum().view(-1)
else:
raise ValueError(f"Unknown granularity {self.granularity}")
if e[s] is None:
e[s] = y
else:
e[s].add_(y)
elif self.objective == SearchBasedCalibObjective.OutputsError:
self._process_ipts_centric_mod(wgts=wgts, mods=mods, **kwargs)
e = [None]
for i in range(len(ipts.front().data)):
ipt = ipts.extract(i, eval_kwargs)
y = eval_module(*ipt.args, **ipt.kwargs)
y = y[0] if not isinstance(y, torch.Tensor) else y
assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
y = (y - orig_opts[(i,)].to(device=y.device, non_blocking=True)).to(self.develop_dtype)
y = y.pow_(self.config.degree).sum().view(-1)
if e[0] is None:
e[0] = y
else:
e[0].add_(y)
del ipt, y
self._recover_mod()
else:
raise ValueError(f"Unknown objective {self.objective}")
# endregion
self.tell(e)
return self.get_best()
    def _calibrate_opts(  # noqa: C901
        self,
        x_wgts: list[torch.Tensor | nn.Parameter],
        y_wgts: list[torch.Tensor | nn.Parameter],
        eval_inputs: TensorsCache | None,
        eval_module: nn.Module | None,
        x_mods: list[nn.Module] | None,
        y_mods: list[nn.Module] | None,
        orig_x_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
        orig_y_wgts: list[tuple[nn.Parameter, torch.Tensor]] | None,
        orig_eval_inputs: TensorsCache | None,
        eval_kwargs: dict[str, tp.Any],
        **kwargs,
    ) -> tp.Any:
        """Run the search-based calibration loop centered on the output activations.

        Only the `OutputsError` objective is supported: the reference outputs of
        `eval_module` are recorded once with the original weights, then every
        search candidate is scored by the summed error of the patched module's
        outputs against those references.

        Args:
            x_wgts (`list[torch.Tensor | nn.Parameter]`):
                The weights that produce the x activations.
            y_wgts (`list[torch.Tensor | nn.Parameter]`):
                The weights that produce the y activations.
            eval_inputs (`TensorsCache` or `None`):
                The cached inputs fed to `eval_module`; must not be `None` here.
            eval_module (`nn.Module` or `None`):
                The module evaluated for each candidate.
            x_mods (`list[nn.Module]` or `None`):
                The modules patched on the x side while evaluating candidates.
            y_mods (`list[nn.Module]` or `None`):
                The modules patched on the y side while evaluating candidates.
            orig_x_wgts (`list[tuple[nn.Parameter, torch.Tensor]]` or `None`):
                Pairs of (parameter, original weight data) for the x side.
            orig_y_wgts (`list[tuple[nn.Parameter, torch.Tensor]]` or `None`):
                Pairs of (parameter, original weight data) for the y side.
            orig_eval_inputs (`TensorsCache` or `None`):
                The reference inputs; defaults to `eval_inputs` when `None`.
            eval_kwargs (`dict[str, tp.Any]`):
                Extra keyword arguments used when extracting cached inputs.
            **kwargs:
                Forwarded to `_process_opts_centric_mod`.

        Returns:
            `tp.Any`:
                The best candidate found by the search (see `get_best`).

        Raises:
            ValueError: If `self.objective` is not `OutputsError`.
        """
        # region Step 1: Calculate the outputs
        if self.objective == SearchBasedCalibObjective.OutputsError:
            assert eval_inputs is not None, "eval_inputs should not be None when objective is OutputsError"
            if orig_eval_inputs is None:
                orig_eval_inputs = eval_inputs
            assert eval_inputs.num_tensors == orig_eval_inputs.num_tensors
            assert all(
                x.shape == orig_x.shape
                for key, ipt in eval_inputs.items()
                for x, orig_x in zip(ipt.data, orig_eval_inputs[key].data, strict=True)
            )
            # temporarily swap in the original weights to record reference outputs
            _x_state_dict, _y_state_dict = [], []
            if orig_x_wgts is not None:
                _x_state_dict = [(p, p.data) for p, _ in orig_x_wgts]
                for p, w in orig_x_wgts:
                    p.data = w.to(device=p.data.device)
            if orig_y_wgts is not None:
                _y_state_dict = [(p, p.data) for p, _ in orig_y_wgts]
                for p, w in orig_y_wgts:
                    p.data = w.to(device=p.data.device)
            # reference outputs keyed by input index
            orig_opts: dict[tuple[int, ...], torch.Tensor] = {}
            for i in range(len(orig_eval_inputs.front().data)):
                ipt = orig_eval_inputs.extract(i, eval_kwargs)
                y = eval_module(*ipt.args, **ipt.kwargs)
                y = y[0] if not isinstance(y, torch.Tensor) else y
                assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
                orig_opts[(i,)] = y.to(device=self.opts_device or y.device, non_blocking=True)
                del ipt, y
            # restore the weights that were swapped out above
            for p, s in _x_state_dict:
                p.data = s
            for p, s in _y_state_dict:
                p.data = s
            del orig_x_wgts, orig_y_wgts, orig_eval_inputs, _x_state_dict, _y_state_dict
        else:
            raise ValueError(f"Unknown objective {self.objective}")
        gc.collect()
        torch.cuda.empty_cache()
        # endregion
        while not self.is_done():
            self.ask()
            e: list[torch.Tensor] = []
            # region Step 2: Calculate the outputs errors
            if self.objective == SearchBasedCalibObjective.OutputsError:
                self._process_opts_centric_mod(
                    x_wgts=x_wgts,
                    y_wgts=y_wgts,
                    x_mods=x_mods,
                    y_mods=y_mods,
                    **kwargs,
                )
                e = [None]
                for i in range(len(eval_inputs.front().data)):
                    ipt = eval_inputs.extract(i, eval_kwargs)
                    y = eval_module(*ipt.args, **ipt.kwargs)
                    y = y[0] if not isinstance(y, torch.Tensor) else y
                    assert isinstance(y, torch.Tensor), "eval_mod should return a tensor"
                    y = (y - orig_opts[(i,)].to(device=y.device, non_blocking=True)).to(self.develop_dtype)
                    y = y.pow_(self.config.degree).sum().view(-1)
                    if e[0] is None:
                        e[0] = y
                    else:
                        e[0].add_(y)
                    del ipt, y
                self._recover_mod()
            else:
                raise ValueError(f"Unknown objective {self.objective}")
            # endregion
            self.tell(e)
        return self.get_best()
================================================
FILE: deepcompressor/calib/smooth.py
================================================
# -*- coding: utf-8 -*-
"""Smooth quantization module."""
import gc
import typing as tp
from dataclasses import _MISSING_TYPE, MISSING, dataclass
import torch
import torch.nn as nn
from ..data.cache import TensorsCache
from ..data.common import TensorType
from ..quantizer.processor import Quantizer
from ..utils import math, tools
from ..utils.common import split_sequence
from ..utils.hooks import BaseInputPackager, BaseOutputPackager, BaseTensorProcessor
from .config import SearchBasedCalibObjective, SmoothCalibConfig, SmoothSpanMode
from .metric import ChannelMetric
from .search import SearchBasedCalibrator
# Public API of the smoothing calibration module.
__all__ = [
    "smooth_linear_modules",
    "smooth_attention",
    "convert_smooth_upscale_to_downscale",
    "ActivationSmoother",
    "get_smooth_scale",
    "get_smooth_span",
    "SmoothCalibrator",
    "SmoothLinearCalibrator",
    "SmoothAttentionCalibrator",
]
@dataclass
class ActivationSmoother(BaseTensorProcessor):
    """Tensor processor that rescales activations channel-wise with a smoothing scale."""

    # The per-channel smoothing scale.
    smooth_scale: torch.Tensor
    # The dimension of the processed tensor that holds the channels.
    channels_dim: int
    # If True, multiply by the scale; otherwise divide by it.
    upscale: bool = False
    # Working dtype; captured from the first processed tensor when None.
    develop_dtype: torch.dtype | None = None
    # region hook-related attributes
    input_packager: BaseInputPackager | None = None
    output_packager: BaseOutputPackager | None = None
    # endregion

    def is_enabled(self) -> bool:
        return self.smooth_scale is not None

    def get_input_packager(self) -> BaseInputPackager | None:
        return self.input_packager

    def get_output_packager(self) -> BaseOutputPackager | None:
        return self.output_packager

    def process(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the channel-wise smoothing scale to the tensor.

        Args:
            tensor (`torch.Tensor`):
                The tensor to smooth.

        Returns:
            `torch.Tensor`:
                The smoothed tensor, cast back to the input dtype.
        """
        orig_device, orig_dtype = tensor.device, tensor.dtype
        # Lazily lock the working dtype to that of the first tensor seen.
        if self.develop_dtype is None:
            self.develop_dtype = orig_dtype
        # Cache the scale on the activation's device in the working dtype.
        self.smooth_scale = self.smooth_scale.to(device=orig_device, dtype=self.develop_dtype)
        view_shape = [1] * tensor.ndim
        view_shape[self.channels_dim] = -1
        scale = self.smooth_scale.view(view_shape)
        work = tensor.to(dtype=self.develop_dtype)
        result = work.mul(scale) if self.upscale else work.div(scale)
        return result.to(dtype=orig_dtype)
@torch.inference_mode()
def get_smooth_span(
    tensors: tp.Sequence[torch.Tensor],
    /,
    *,
    group_shape: tp.Sequence[int],
    span_mode: SmoothSpanMode,
    device: torch.device | str | None = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Calculate the per-channel value span of tensors used to build a smoothing scale.

    Args:
        tensors (`Sequence[torch.Tensor]`):
            Tensors to calculate the span; the channel count is taken from dim 1 of the first tensor.
        group_shape (`Sequence[int]`):
            Quantization group shape.
        span_mode (`SmoothSpanMode`):
            The quantization smoothing span mode; its name selects the `ChannelMetric` reduction.
        device (`torch.device` or `str` or `None`, *optional*, defaults to `None`):
            Device to store the span.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the span.

    Returns:
        `torch.Tensor`:
            The span of the tensors for calculating smoothing scale.
    """
    # Translate the CamelCase span-mode name into the snake_case ChannelMetric method name.
    pieces: list[str] = []
    for ch in span_mode.name:
        if ch.isupper() and pieces:
            pieces.append("_")
        pieces.append(ch.lower())
    metric_fn = getattr(ChannelMetric, "".join(pieces))
    span: torch.Tensor = metric_fn(tensors, tensors[0].shape[1], group_shape, device=device, dtype=dtype)
    return span
@torch.inference_mode()
def get_smooth_scale(*, alpha_base: torch.Tensor, beta_base: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
    """Calculate the quantization smoothing scale ``alpha_base^alpha / beta_base^beta``.

    Args:
        alpha_base (`torch.Tensor`):
            Base span for alpha.
        beta_base (`torch.Tensor`):
            Base span for beta.
        alpha (`float`):
            Exponent applied to ``alpha_base``; must lie in [0, 1].
        beta (`float`):
            Exponent applied to ``beta_base``; must lie in [0, 1].

    Returns:
        `torch.Tensor`:
            Smoothing scale, with zero entries replaced by one; if any entry is
            NaN or Inf, the whole scale falls back to ones.
    """
    assert 0 <= alpha <= 1 and 0 <= beta <= 1, "The smooth factors should be in [0, 1]."
    if alpha <= 0:
        # alpha_base^0 == 1, so the scale reduces to beta_base^(-beta).
        scale = beta_base.pow(-beta)
    else:
        scale = alpha_base.pow(alpha)
        if beta > 0:
            scale = scale / beta_base.pow(beta)
    # Zero entries would blow up the inverse scale downstream; neutralize them.
    scale = scale.masked_fill_(scale == 0, 1)
    # Fall back to an identity scale if anything overflowed or is undefined.
    if not torch.isfinite(scale).all():
        scale = scale.fill_(1)
    assert not scale.isnan().any(), "The smooth scale contains NaN."
    assert not scale.isinf().any(), "The smooth scale contains Inf."
    return scale
class SmoothCalibrator(SearchBasedCalibrator[SmoothCalibConfig, torch.Tensor]):
    """The quantization smoothing calibrator."""
    def __init__(
        self,
        tensor_type: TensorType,
        config: SmoothCalibConfig,
        w_quantizer: Quantizer | None,
        x_quantizer: Quantizer | None,
        y_quantizer: Quantizer | None,
        num_heads: int = 1,
        num_head_repeats: int = 1,
        with_rope: bool = False,
        develop_dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the calibrator.
        Args:
            tensor_type (`TensorType`):
                The type of tensor to quantize. Choices are ``Weights`` and ``Outputs``.
            config (`SmoothCalibConfig`):
                The quantization smoothing calibration configuration.
            w_quantizer (`Quantizer` or `None`):
                The w quantizer for x-w computation.
            x_quantizer (`Quantizer` or `None`):
                The x quantizer for x-w or y-x computation.
            y_quantizer (`Quantizer` or `None`):
                The y quantizer for y-x computation.
            num_heads (`int`, *optional*, defaults to ``1``):
                The number of heads.
            num_head_repeats (`int`, *optional*, defaults to ``1``):
                The number of head repeats.
            with_rope (`bool`, *optional*, defaults to ``False``):
                Whether rotary position embedding is used for y-x computation.
            develop_dtype (torch.dtype, *optional*, defaults to ``torch.float32``):
                The development data type.
        """
        assert tensor_type in (TensorType.Weights, TensorType.Outputs)
        super().__init__(
            tensor_type=tensor_type,
            config=config,
            w_quantizer=w_quantizer,
            x_quantizer=x_quantizer,
            y_quantizer=y_quantizer,
            develop_dtype=develop_dtype,
        )
        self.num_heads = num_heads
        self.num_head_repeats = num_head_repeats
        # rope only applies to the outputs-centric (y-x) case
        self.with_rope = self.tensor_type != TensorType.Weights and with_rope
        # region set group shapes of weights, inputs and outputs
        if self.needs_w_quant:
            w_group_shape = list(self.w_quantizer.config.largest_group_shape)
        else:
            w_group_shape = [1, None, -1]
        if self.needs_x_quant:
            x_group_shape = list(self.x_quantizer.config.largest_group_shape)
        else:
            x_group_shape = [1, None, -1]
        if self.needs_y_quant:
            y_group_shape = list(self.y_quantizer.config.largest_group_shape)
        else:
            y_group_shape = [1, None, -1]
        # resolve unspecified (None) channel group sizes from the paired tensor's group shape
        w_group_shape[1] = x_group_shape[1] if w_group_shape[1] is None else w_group_shape[1]
        if self.tensor_type == TensorType.Weights:
            x_group_shape[1] = w_group_shape[1] if x_group_shape[1] is None else x_group_shape[1]
        else:
            x_group_shape[1] = y_group_shape[1] if x_group_shape[1] is None else x_group_shape[1]
        y_group_shape[1] = x_group_shape[1] if y_group_shape[1] is None else y_group_shape[1]
        self.w_group_shape, self.x_group_shape, self.y_group_shape = w_group_shape, x_group_shape, y_group_shape
        # endregion
        # candidate grid: (alpha, beta) exponent pairs, crossed with span-mode pairs
        self.alpha_beta_pairs = self.config.get_alpha_beta_pairs()
        # the smoothing search runs in a single iteration
        self.num_iters = 1
    @property
    def population_size(self) -> int:
        """Get the population size."""
        return len(self.alpha_beta_pairs) * len(self.span_mode_pairs)
    @property
    def allows_x_quant_for_wgts(self) -> bool:
        """Whether the calibrator allows input quantization when tensor_type is Weights."""
        return self.config.allow_a_quant
    @property
    def allows_w_quant_for_wgts(self) -> bool:
        """Whether the calibrator needs weight quantization when tensor_type is Weights."""
        return self.config.allow_b_quant
    @property
    def allows_w_quant_for_ipts(self) -> bool:
        """Whether the calibrator allows weight quantization when tensor_type is Inputs."""
        return self.config.allow_b_quant
    @property
    def allows_x_quant_for_opts(self) -> bool:
        """Whether the calibrator allows x quantization when tensor_type is Outputs."""
        return self.config.allow_b_quant
    @property
    def allows_y_quant_for_opts(self) -> bool:
        """Whether the calibrator allows y quantization when tensor_type is Outputs."""
        return self.config.allow_a_quant
    @property
    def allows_w_quant_for_opts(self) -> bool:
        """Whether the calibrator allows weight quantization when tensor_type is Outputs."""
        return False
    @property
    def span_mode_pairs(self) -> list[tuple[SmoothSpanMode, SmoothSpanMode]]:
        """Get the span modes."""
        return self.config.spans
    @property
    def alpha_span_modes(self) -> list[SmoothSpanMode]:
        """Get the span modes for alpha."""
        return self.config.a_spans
    @property
    def beta_span_modes(self) -> list[SmoothSpanMode]:
        """Get the span modes for beta."""
        return self.config.b_spans
    def _reset(  # noqa: C901
        self,
        *,
        x_wgts: list[torch.Tensor | nn.Parameter],
        x_acts: TensorsCache,
        y_wgts: list[torch.Tensor | nn.Parameter] | None = None,
        y_acts: TensorsCache | None = None,
        **kwargs,
    ) -> None:
        """Reset the calibrator.
        Args:
            x_wgts (`list[torch.Tensor | nn.Parameter]`):
                The weights in x-w computation, or weights that generates x for y-x computation.
            x_acts (`TensorsCache`):
                The x activations. It should be x for x-w or y-x computation.
            y_wgts (`list[torch.Tensor | nn.Parameter]` or `None`, *optional*, defaults to `None`):
                The weights that generates y for y-x computation.
            y_acts (`TensorsCache` or `None`, *optional*, defaults to `None`):
                The y activations. It should be y for y-x computation.
        """
        wgts_centric = self.tensor_type == TensorType.Weights
        # weights-centric: smoothed channels are the weight input channels (dim 1);
        # outputs-centric: they are the weight output channels (dim 0)
        self.num_in_channels = x_wgts[0].shape[1] if wgts_centric else x_wgts[0].shape[0]
        device = x_wgts[0].device
        if self.num_heads > 1 and self.num_head_repeats > 1:
            self.num_unique_heads = self.num_heads // self.num_head_repeats
        else:
            self.num_unique_heads = 0
        # region get x spans
        assert (
            x_acts.num_tensors == 1
        ), f"Only one input is allowed, got {x_acts.num_tensors}=len({list(x_acts.keys())})"
        x_tensors = x_acts.front().get_standardized_data(reshape=False)
        assert all(x.shape[1] == self.num_in_channels for x in x_tensors)
        x_spans = {}
        for span_mode in self.alpha_span_modes if wgts_centric else self.beta_span_modes:
            x_span = get_smooth_span(
                x_tensors,
                group_shape=self.x_group_shape,
                span_mode=span_mode,
                device=device,
                dtype=self.develop_dtype,
            )
            if self.num_unique_heads > 0:
                # collapse repeated heads so every repeat shares one span value
                x_span = x_span.view(self.num_unique_heads, self.num_head_repeats, -1)
                x_span = (x_span.amax if "Max" in span_mode.name else x_span.mean)(dim=1, keepdim=True)
                x_span = x_span.expand(self.num_unique_heads, self.num_head_repeats, -1).reshape(-1)
            if self.tensor_type == TensorType.Outputs and self.with_rope:
                # NOTE(review): assumes rope pairs channels across the two halves of each head dim — confirm
                x_span = x_span.view(self.num_heads, 2, -1)
                x_span = (x_span.amax if "Max" in span_mode.name else x_span.mean)(dim=1, keepdim=True)
                x_span = x_span.expand(self.num_heads, 2, -1).reshape(-1)
            x_spans[span_mode] = x_span
            if self.logger.level <= tools.logging.DEBUG:
                self.logger.debug("+ x - %s", span_mode.name)
                self.logger.debug("+ x = [min=%.4f, max=%.4f]", x_span.min().item(), x_span.max().item())
        del x_tensors
        # endregion
        if wgts_centric:
            assert all(w.shape[1] == self.num_in_channels for w in x_wgts)
            w_tensors = [w.data for w in x_wgts]
            w_spans = {}
            for span_mode in self.beta_span_modes:
                w_span = get_smooth_span(
                    w_tensors,
                    group_shape=self.w_group_shape,
                    span_mode=span_mode,
                    dtype=self.develop_dtype,
                )
                if self.num_unique_heads > 0:
                    w_span = w_span.view(self.num_unique_heads, self.num_head_repeats, -1)
                    w_span = (w_span.amax if "Max" in span_mode.name else w_span.mean)(dim=1, keepdim=True)
                    w_span = w_span.expand(self.num_unique_heads, self.num_head_repeats, -1).reshape(-1)
                w_spans[span_mode] = w_span
                if self.logger.level <= tools.logging.DEBUG:
                    self.logger.debug("+ w - %s", span_mode.name)
                    self.logger.debug("+ w = [min=%.4f, max=%.4f]", w_span.min().item(), w_span.max().item())
            # (alpha base, beta base) span tensors, one pair per configured span-mode pair
            self.span_pairs: list[tuple[torch.Tensor, torch.Tensor]] = [
                (x_spans[x_span_mode], w_spans[w_span_mode]) for x_span_mode, w_span_mode in self.span_mode_pairs
            ]
        else:
            assert y_acts.num_tensors == 1, f"Only one output source is allowed, got {y_acts.num_tensors}"
            if self.num_unique_heads > 0:
                num_out_channels = self.num_in_channels // self.num_head_repeats
            else:
                num_out_channels = self.num_in_channels
            assert all(w.shape[0] == self.num_in_channels for w in x_wgts)
            assert all(w.shape[0] == num_out_channels for w in y_wgts)
            y_tensors = y_acts.front().get_standardized_data(reshape=False)
            assert all(y.shape[1] == num_out_channels for y in y_tensors)
            y_spans = {}
            for span_mode in self.alpha_span_modes:
                # NOTE(review): y spans are computed with x_group_shape — confirm whether
                # y_group_shape was intended here
                y_span = get_smooth_span(
                    y_tensors,
                    group_shape=self.x_group_shape,
                    span_mode=span_mode,
                    device=device,
                    dtype=self.develop_dtype,
                )
                if self.num_unique_heads > 0:
                    # broadcast each unique head's span across its repeats
                    y_span = y_span.view(self.num_unique_heads, 1, -1)
                    y_span = y_span.expand(self.num_unique_heads, self.num_head_repeats, -1).reshape(-1)
                if self.tensor_type == TensorType.Outputs and self.with_rope:
                    y_span = y_span.view(self.num_heads, 2, -1)
                    y_span = (y_span.amax if "Max" in span_mode.name else y_span.mean)(dim=1, keepdim=True)
                    y_span = y_span.expand(self.num_heads, 2, -1).reshape(-1)
                y_spans[span_mode] = y_span
                if self.logger.level <= tools.logging.DEBUG:
                    self.logger.debug("+ y - %s", span_mode.name)
                    self.logger.debug("+ y = [min=%.4f, max=%.4f]", y_span.min().item(), y_span.max().item())
            self.span_pairs: list[tuple[torch.Tensor, torch.Tensor]] = [
                (y_spans[y_span_mode], x_spans[x_span_mode]) for y_span_mode, x_span_mode in self.span_mode_pairs
            ]
        # per-candidate search state
        self.best_error: list[torch.Tensor] | None = None
        self.best_scale: torch.Tensor | None = None
        self.error_history: list[tuple[float, float]] = []
    def _split_candidate_id(self, candidate_id: int) -> tuple[int, int]:
        """Split the candidate id into alpha_beta id and span_pair id.
        Args:
            candidate_id (`int`):
                The candidate id.
        Returns:
            `tuple[int, int]`:
                The alpha_beta id and span_mode id.
        """
        alpha_beta_id = candidate_id % len(self.alpha_beta_pairs)
        span_pair_id = candidate_id // len(self.alpha_beta_pairs)
        return alpha_beta_id, span_pair_id
    def get_best(self) -> torch.Tensor:
        """Get the best candidate.
        Returns:
            `torch.Tensor`:
                The best candidate.
        """
        return self.best_scale
    def _ask(self) -> torch.Tensor:
        """Ask for the next candidate.
        Returns:
            `torch.Tensor`:
                The next candidate.
        """
        alpha_beta_id, span_pair_id = self._split_candidate_id(self.candidate_id)
        alpha, beta = self.alpha_beta_pairs[alpha_beta_id]
        a_span, b_span = self.span_pairs[span_pair_id]
        if alpha == 0 and beta == 0:
            # (0, 0) is the identity candidate: no smoothing
            scale = torch.ones_like(a_span, dtype=self.develop_dtype)
        else:
            scale = get_smooth_scale(alpha_base=a_span, beta_base=b_span, alpha=alpha, beta=beta)
        return scale
    def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
        """Tell the error of the last candidate and update the best candidate.
        Args:
            error (`list[torch.Tensor]`):
                The error of the last candidate.
        """
        numel = error[0].numel()
        assert all(e.numel() == numel for e in error)
        scale = self.candidate
        self.best_error, self.best_scale = self._update_best(
            best_error=self.best_error,
            best_scale=self.best_scale,
            error=error,
            scale=scale,
            numel=numel,
            num_channels=self.num_in_channels,
            num_heads=self.num_heads,
            num_head_repeats=self.num_head_repeats,
        )
        if self.logger.level <= tools.logging.DEBUG:
            # record (candidate error, best-so-far error) for the end-of-iteration dump
            self.error_history.append(
                (
                    sum(math.root_(e.to(torch.float64).sum(), self.config.degree).item() for e in error),
                    sum(math.root_(b.to(torch.float64).sum(), self.config.degree).item() for b in self.best_error),
                )
            )
            if self.is_last_candidate_in_iter():
                # group debug rows per span-mode pair, five candidates per row
                logs: list[list[list[tuple]]] = [[] for _ in range(len(self.span_mode_pairs))]
                for i in range(self.population_size):
                    c, r = self._split_candidate_id(i)
                    alpha, beta = self.alpha_beta_pairs[c]
                    if c % 5 == 0:
                        logs[r].append([])
                    logs[r][-1].append((alpha, beta, self.error_history[i][0], self.error_history[i][1]))
                for r in range(len(self.span_mode_pairs)):
                    self.logger.debug(
                        "  - x / w range = %s / %s", self.span_mode_pairs[r][0].name, self.span_mode_pairs[r][1].name
                    )
                    for log in logs[r]:
                        self.logger.debug(
                            "    - alpha      = [%s]",
                            ", ".join(f"{alpha:10.4f}" for alpha, beta, e, b in log),
                        )
                        self.logger.debug(
                            "    - beta       = [%s]",
                            ", ".join(f"{beta:10.4f}" for alpha, beta, e, b in log),
                        )
                        self.logger.debug(
                            "    - sum  error = [%s]", ", ".join(f"{e:10.4f}" for alpha, beta, e, b in log)
                        )
                        self.logger.debug(
                            "    - best error = [%s]",
                            ", ".join(f"{b:10.4f}" for alpha, beta, e, b in log),
                        )
                del logs
                self.error_history.clear()
            if self.is_last_iter():
                scale = self.get_best()
                tools.logging.Formatter.indent_dec()
                self.logger.debug(
                    "+ error = %.4f",
                    sum(math.root_(b.to(torch.float64).sum(), self.config.degree).item() for b in self.best_error),
                )
                self.logger.debug("+ scale = [min=%.4f, max=%.4f]", scale.min().item(), scale.max().item())
                tools.logging.Formatter.indent_inc()
    def _reshape_scale(
        self, scale: torch.Tensor, tensor: torch.Tensor, channels_dim: int, needs_reduction: bool = False
    ) -> torch.Tensor:
        """Reshape the per-channel scale so it broadcasts against `tensor` along `channels_dim`.

        When `needs_reduction` is set and heads are repeated, only the first repeat
        of each unique head is kept before reshaping.
        """
        if self.num_unique_heads > 0 and needs_reduction:
            scale = scale.view(self.num_unique_heads, self.num_head_repeats, -1)[:, 0, :].reshape(-1)
        shape = [1] * tensor.ndim
        shape[channels_dim] = -1
        return scale.view(shape)
    def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        """Simulate input quantization under the candidate scale (divide, quantize, multiply back)."""
        if not self.needs_x_quant_for_wgts:
            return x
        if channels_dim is MISSING:
            channels_dim = self.x_quantizer.channels_dim
        shape, dtype = x.shape, x.dtype
        scale = self._reshape_scale(self.candidate, x, channels_dim)
        x = x.to(dtype=self.develop_dtype) if dtype != self.develop_dtype else x.clone()
        x = x.div_(scale)
        x = self.x_quantizer.quantize(
            x, channels_dim=channels_dim, default_dtype=dtype, develop_dtype=self.develop_dtype
        ).data
        x = x.mul_(scale).to(dtype=dtype)
        return x.view(shape)
    def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
        """Simulate weight quantization under the candidate scale (multiply, quantize, divide back)."""
        if not self.needs_w_quant_for_wgts:
            return w
        dtype = w.dtype
        channels_dim = 1 if self.w_quantizer.channels_dim is None else self.w_quantizer.channels_dim
        scale = self._reshape_scale(self.candidate, w, channels_dim=channels_dim)
        w = w.to(dtype=self.develop_dtype) if dtype != self.develop_dtype else w.clone()
        w = self.w_quantizer.quantize(
            w.mul_(scale), kernel=None, default_dtype=dtype, develop_dtype=self.develop_dtype
        ).data
        w = w.div_(scale).to(dtype=dtype)
        return w
    def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        """Simulate x quantization for the y-x case; skips the explicit scaling for `OutputsError`."""
        if not self.needs_x_quant_for_opts:
            return x
        shape, dtype = x.shape, x.dtype
        if self.objective != SearchBasedCalibObjective.OutputsError:
            if channels_dim is MISSING:
                channels_dim = self.x_quantizer.channels_dim
            scale = self._reshape_scale(self.candidate, x, channels_dim, needs_reduction=False)
            x = x.to(dtype=self.develop_dtype) if dtype != self.develop_dtype else x.clone()
            x = x.mul_(scale)
        # ! `x` is already scaled during `_process_opts_centric_mod` by scaling `xw`
        x = self.x_quantizer.quantize(
            x,
            channels_dim=channels_dim,
            default_dtype=dtype,
            develop_dtype=self.develop_dtype,
        ).data
        if self.objective != SearchBasedCalibObjective.OutputsError:
            x = x.div_(scale).to(dtype=dtype)
        return x.view(shape)
    def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSING_TYPE = MISSING) -> torch.Tensor:
        """Simulate y quantization for the y-x case; skips the explicit scaling for `OutputsError`."""
        if not self.needs_y_quant_for_opts:
            return y
        shape, dtype = y.shape, y.dtype
        if self.objective != SearchBasedCalibObjective.OutputsError:
            if channels_dim is MISSING:
                channels_dim = self.x_quantizer.channels_dim
            scale = self._reshape_scale(self.candidate, y, channels_dim, needs_reduction=True)
            y = y.to(dtype=self.develop_dtype) if dtype != self.develop_dtype else y.clone()
            y = y.div_(scale)
        # ! `y` is already scaled during `_process_opts_centric_mod` by scaling `yw`
        y = self.y_quantizer.quantize(
            y,
            channels_dim=channels_dim,
            default_dtype=dtype,
            develop_dtype=self.develop_dtype,
        ).data
        if self.objective != SearchBasedCalibObjective.OutputsError:
            y = y.mul_(scale).to(dtype=dtype)
        return y.view(shape)
    def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        """Unsupported for this calibrator; weights are handled in `_process_opts_centric_mod`."""
        raise RuntimeError("The method `_process_xw_in_yx` should not be called in SmoothCalibrator.")
    def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
        """Unsupported for this calibrator; weights are handled in `_process_opts_centric_mod`."""
        raise RuntimeError("The method `_process_yw_in_yx` should not be called in SmoothCalibrator.")
    def _process_wgts_centric_mod(
        self,
        wgts: list[nn.Parameter],
        mods: list[nn.Module],
        update_state_dict: bool = True,
        splits: list[int] | None = None,
        **kwargs,
    ) -> None:
        """Patch modules for weights-centric evaluation of the current candidate.

        With low-rank weight quantization enabled, this scales the weights in place,
        installs `ActivationSmoother` hooks on the inputs, and replaces the weights
        with their (low-rank) quantized counterparts; otherwise it defers to the base
        implementation.
        """
        if self.needs_w_quant_for_wgts and self.config.allow_low_rank and self.w_quantizer.is_enabled_low_rank():
            assert len(wgts) == len(mods)
            for wgt in wgts:
                if update_state_dict:
                    # remember the original data so `_recover_mod` can restore it
                    self._state_dict.append((wgt, wgt.data))
                dtype = wgt.dtype
                scale = self._reshape_scale(self.candidate, wgt.data, channels_dim=1)
                wgt.data = wgt.data.to(dtype=self.develop_dtype).mul(scale).to(dtype=dtype)
            input_packager = self.x_quantizer.get_input_packager() if self.needs_x_quant else None
            for mod in mods:
                self._hooks.append(
                    ActivationSmoother(
                        self.candidate,
                        self.x_quantizer.channels_dim,
                        develop_dtype=self.develop_dtype,
                        input_packager=input_packager,
                    )
                    .as_hook()
                    .register(mod)
                )
            if splits:
                wgts_splits: list[list[nn.Parameter]] = split_sequence(wgts, splits)
                mods_splits: list[list[nn.Module]] = split_sequence(mods, splits)
            else:
                wgts_splits, mods_splits = [wgts], [mods]
            for wgts_split, mods_split in zip(wgts_splits, mods_splits, strict=True):
                for qwgt, lowr, wgt, mod in zip(
                    *self.w_quantizer.quantize_with_low_rank(wgts_split, kernel=None, develop_dtype=self.develop_dtype),
                    wgts_split,
                    mods_split,
                    strict=True,
                ):
                    wgt.data = qwgt.data
                    self._hooks.append(lowr.as_hook(input_packager=input_packager).register(mod))
                    if self.needs_x_quant_for_wgts:
                        self._hooks.append(self.x_quantizer.as_hook().register(mod))
        else:
            super()._process_wgts_centric_mod(wgts=wgts, mods=mods, update_state_dict=update_state_dict, **kwargs)
    def _process_opts_centric_mod(
        self,
        x_wgts: list[nn.Parameter],
        y_wgts: list[nn.Parameter],
        x_mods: list[nn.Module],
        y_mods: list[nn.Module],
        update_state_dict: bool = True,
        **kwargs,
    ) -> None:
        """Patch modules for outputs-centric evaluation of the current candidate.

        Scales the x-producing weights up and the y-producing weights down by the
        candidate scale (so the activations come out pre-scaled), then delegates to
        the base implementation without re-recording the state dict.
        """
        for w in x_wgts:
            if update_state_dict:
                self._state_dict.append((w, w.data))
            scale = self._reshape_scale(self.candidate, w, channels_dim=0, needs_reduction=False)
            w.data = w.detach().data.to(dtype=self.develop_dtype).mul(scale).to(dtype=w.dtype)
        for w in y_wgts:
            if update_state_dict:
                self._state_dict.append((w, w.data))
            scale = self._reshape_scale(self.candidate, w, channels_dim=0, needs_reduction=True)
            w.data = w.detach().data.to(dtype=self.develop_dtype).div(scale).to(dtype=w.dtype)
        super()._process_opts_centric_mod(
            x_wgts=x_wgts,
            y_wgts=y_wgts,
            x_mods=x_mods,
            y_mods=y_mods,
            update_state_dict=False,
            **kwargs,
        )
    @staticmethod
    def _update_best(
        *,
        best_error: list[torch.Tensor] | None,
        best_scale: torch.Tensor,
        error: list[torch.Tensor],
        scale: torch.Tensor,
        numel: int,
        num_channels: int,
        num_heads: int,
        num_head_repeats: int,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Merge the candidate's error/scale into the best-so-far, element-wise per error group.

        For tensor-wise errors (`numel == 1`) the candidate wins only if it improves
        every error term; otherwise each channel group is compared independently and
        the winning groups' scales are copied into `best_scale`.
        """
        if best_error is None:
            return error, scale
        elif numel == 1:  # tensor wise quantization error
            if all(e <= b for b, e in zip(best_error, error, strict=True)):
                return error, scale
            return best_error, best_scale
        else:  # channel group wise quantization error
            assert num_channels % numel == 0
            group_size, num_groups = num_channels // numel, numel
            needs_reduction = num_heads > 1 and num_head_repeats > 1
            if needs_reduction:
                # align error-group boundaries with repeated-head boundaries so that
                # all repeats of a unique head receive the same decision
                num_head_channels = num_channels // num_heads
                num_unique_heads = num_heads // num_head_repeats
                if num_head_channels >= group_size:
                    assert num_head_channels % group_size == 0
                    num_groups_per_head = num_head_channels // group_size
                    num_repeats = num_head_repeats
                    num_unqiue_heads_per_group = 1
                else:
                    assert group_size % num_head_channels == 0
                    num_heads_per_group = group_size // num_head_channels
                    if num_heads_per_group < num_head_repeats:
                        assert num_head_repeats % num_heads_per_group == 0
                        num_groups_per_head = 1
                        num_repeats = num_head_repeats // num_heads_per_group
                        num_unqiue_heads_per_group = 1
                    else:
                        assert num_heads_per_group % num_head_repeats == 0
                        num_groups_per_head = 1
                        num_repeats = 1
                        num_unqiue_heads_per_group = num_heads_per_group // num_head_repeats
                num_uniques = num_unique_heads // num_unqiue_heads_per_group
                needs_reduction = needs_reduction and num_repeats > 1
            # pos[g] is True iff the candidate beats the best on every error term for group g
            pos = torch.full((numel,), True, device=error[0][0].device)
            for b, e in zip(best_error, error, strict=True):
                if needs_reduction:
                    b = b.view(num_uniques, num_repeats, num_groups_per_head).sum(dim=1, keepdim=True)
                    e = e.view(num_uniques, num_repeats, num_groups_per_head).sum(dim=1, keepdim=True)
                    pos = pos & (e < b).expand(num_uniques, num_repeats, num_groups_per_head).reshape_as(pos)
                else:
                    pos = pos & (e < b)
            for b, e in zip(best_error, error, strict=True):
                b[pos] = e[pos]
            pos = pos.view(num_groups, 1).expand(num_groups, group_size)
            best_scale = best_scale.view(num_groups, group_size)
            best_scale[pos] = scale.view(num_groups, group_size)[pos]
            return best_error, best_scale
class SmoothLinearCalibrator(SmoothCalibrator):
    """Smooth quantization calibrator specialized for linear modules."""

    def __init__(
        self,
        config: SmoothCalibConfig,
        weight_quantizer: Quantizer | None,
        input_quantizer: Quantizer | None,
        num_heads: int = 1,
        num_head_repeats: int = 1,
        develop_dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the linear-module calibrator.

        Args:
            config (`SmoothCalibConfig`):
                The quantization smoothing calibration configuration.
            weight_quantizer (`Quantizer` or `None`):
                The quantizer applied to the linear weights.
            input_quantizer (`Quantizer` or `None`):
                The quantizer applied to the input activations.
            num_heads (`int`, *optional*, defaults to `1`):
                The number of heads.
            num_head_repeats (`int`, *optional*, defaults to `1`):
                The number of head repeats.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The development data type.
        """
        # For linear smoothing the weights are the smoothed tensors: the
        # weight quantizer acts as the "w" quantizer and the input quantizer
        # as the "x" quantizer of the generic SmoothCalibrator; no "y"
        # quantizer is involved.
        super().__init__(
            tensor_type=TensorType.Weights,
            config=config,
            w_quantizer=weight_quantizer,
            x_quantizer=input_quantizer,
            y_quantizer=None,
            num_heads=num_heads,
            num_head_repeats=num_head_repeats,
            develop_dtype=develop_dtype,
        )
class SmoothAttentionCalibrator(SmoothCalibrator):
    """Smooth quantization calibrator specialized for attention query/key smoothing."""

    def __init__(
        self,
        config: SmoothCalibConfig,
        query_quantizer: Quantizer | None,
        key_quantizer: Quantizer | None,
        num_heads: int = 1,
        num_head_repeats: int = 1,
        with_rope: bool = True,
        develop_dtype: torch.dtype = torch.float32,
    ) -> None:
        """Initialize the attention calibrator.

        Args:
            config (`SmoothCalibConfig`):
                The quantization smoothing calibration configuration.
            query_quantizer (`Quantizer` or `None`):
                The quantizer applied to the queries.
            key_quantizer (`Quantizer` or `None`):
                The quantizer applied to the keys.
            num_heads (`int`, *optional*, defaults to `1`):
                The number of heads.
            num_head_repeats (`int`, *optional*, defaults to `1`):
                The number of head repeats.
            with_rope (`bool`, *optional*, defaults to `True`):
                Whether rotary position embedding is used.
            develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                The development data type.
        """
        # Queries play the role of "x" and keys the role of "y" in the generic
        # SmoothCalibrator; there is no weight quantizer for attention smoothing.
        super().__init__(
            tensor_type=TensorType.Outputs,
            config=config,
            w_quantizer=None,
            x_quantizer=query_quantizer,
            y_quantizer=key_quantizer,
            num_heads=num_heads,
            num_head_repeats=num_head_repeats,
            with_rope=with_rope,
            develop_dtype=develop_dtype,
        )

    def calibrate(
        self,
        q_proj_weight: nn.Parameter,
        k_proj_weight: nn.Parameter,
        queries: TensorsCache,
        keys: TensorsCache,
        query_module: nn.Module,
        key_module: nn.Module,
        eval_module: nn.Module | None = None,
        eval_inputs: TensorsCache | None = None,
        eval_kwargs: dict[str, tp.Any] | None = None,
    ) -> tp.Any:
        """Calibrate the smoothing for attention query/key quantization.

        Args:
            q_proj_weight (`nn.Parameter`):
                The query projection weight.
            k_proj_weight (`nn.Parameter`):
                The key projection weight.
            queries (`TensorsCache`):
                The query activations.
            keys (`TensorsCache`):
                The key activations.
            query_module (`nn.Module`):
                The module that generates the query activations,
                e.g., either `q_proj` for pre-rope or `q_rotary_emb` for post-rope.
            key_module (`nn.Module`):
                The module that generates the key activations,
                e.g., either `k_proj` for pre-rope or `k_rotary_emb` for post-rope.
            eval_module (`nn.Module`, *optional*):
                The evaluation module.
            eval_inputs (`TensorsCache`, *optional*):
                The evaluation inputs.
            eval_kwargs (`dict[str, tp.Any]`, *optional*):
                The evaluation keyword arguments.

        Returns:
            tp.Any: The evaluation result.
        """
        # Map the attention-specific arguments onto the generic x/y calibration
        # interface of the base class.
        return super().calibrate(
            x_wgts=[q_proj_weight],
            y_wgts=[k_proj_weight],
            x_acts=queries,
            y_acts=keys,
            x_mods=[query_module],
            y_mods=[key_module],
            eval_module=eval_module,
            eval_inputs=eval_inputs,
            eval_kwargs=eval_kwargs,
        )
def smooth_upscale_param(param: nn.Parameter, scale: torch.Tensor, channels_dim: int = 1) -> None:
    """In-place smooth the parameter by upscaling.

    The parameter is multiplied by ``scale`` along ``channels_dim`` in the
    scale's dtype and cast back to its original dtype afterwards.

    Args:
        param (`nn.Parameter`):
            The parameter to smooth.
        scale (`torch.Tensor`):
            The scale to upscale.
        channels_dim (`int`, *optional*, defaults to `1`):
            The dimension of channels.
    """
    orig_dtype = param.dtype
    # Broadcast the 1-D scale along the channel dimension only.
    broadcast_shape = [1] * param.ndim
    broadcast_shape[channels_dim] = -1
    factor = scale.to(device=param.device).view(broadcast_shape)
    param.data = param.data.to(dtype=factor.dtype).mul_(factor).to(dtype=orig_dtype)
    assert not param.data.isnan().any(), "NaN found in param when smoothing"
    assert not param.data.isinf().any(), "Inf found in param when smoothing"
def smooth_downscale_param(param: nn.Parameter, scale: torch.Tensor, channels_dim: int = 0) -> None:
    """In-place smooth the parameter by downscaling.

    Only the first ``scale.numel()`` channels along ``channels_dim`` are
    divided by ``scale``; any remaining channels are left untouched. The
    division happens in the scale's dtype and the result is cast back to the
    parameter's original dtype.

    Args:
        param (`nn.Parameter`):
            The parameter to smooth.
        scale (`torch.Tensor`):
            The scale to downscale.
        channels_dim (`int`, *optional*, defaults to `0`):
            The dimension of channels.
    """
    orig_dtype = param.dtype
    broadcast_shape = [1] * param.ndim
    broadcast_shape[channels_dim] = -1
    divisor = scale.to(device=param.device).view(broadcast_shape)
    working = param.data.to(dtype=divisor.dtype)
    # narrow() returns a view, so the in-place division updates `working`.
    working.narrow(channels_dim, 0, divisor.numel()).div_(divisor)
    param.data = working.to(dtype=orig_dtype)
    assert not param.data.isnan().any(), "NaN found in param when smoothing"
    assert not param.data.isinf().any(), "Inf found in param when smoothing"
def convert_smooth_upscale_to_downscale(
    scale: torch.Tensor, num_heads: int = 1, num_head_repeats: int = 1
) -> torch.Tensor:
    """Convert the upscale smooth scale to downscale smooth scale.

    When heads are repeated, the upscale covers every repeated head; the
    downscale keeps only the first repeat of each unique head. Without head
    repetition the scale is returned unchanged.

    Args:
        scale (`torch.Tensor`):
            The upscale smooth scale.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        num_head_repeats (`int`, *optional*, defaults to `1`):
            The number of head repeats.

    Returns:
        `torch.Tensor`:
            The downscale smooth scale.
    """
    if num_heads <= 1 or num_head_repeats <= 1:
        return scale
    head_channels = scale.numel() // num_heads
    num_unique_heads = num_heads // num_head_repeats
    # Group as (unique head, repeat, channel) and drop all but repeat 0.
    grouped = scale.view(num_unique_heads, num_head_repeats, head_channels)
    return grouped[:, 0, :].reshape(-1)
@torch.inference_mode()
def smooth_linear_modules(
    prevs: nn.Module | tp.Sequence[nn.Module] | None,
    modules: tp.Sequence[nn.Linear] | nn.Linear,
    *,
    scale: torch.Tensor | None,
    config: SmoothCalibConfig | None = None,
    weight_quantizer: Quantizer | None = None,
    input_quantizer: Quantizer | None = None,
    weights: list[nn.Parameter] | None = None,
    inputs: TensorsCache | None = None,
    eval_inputs: TensorsCache | None = None,
    eval_module: nn.Module | None = None,
    eval_kwargs: dict[str, tp.Any] | None = None,
    num_heads: int = 1,
    num_head_repeats: int = 1,
    splits: list[int] | None = None,
    extra_modules: list[nn.Linear] | None = None,
    develop_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Smooth two consecutive modules.

    Args:
        prevs (`nn.Module` or `list[nn.Module]`):
            The first module(s).
        modules (`nn.Linear` or `list[nn.Linear]`):
            The second module(s).
        scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The smooth quantization scale. If `None`, it is calibrated from
            `inputs`/`eval_inputs`.
        config (`SmoothCalibConfig` or `None`, *optional*, defaults to `None`):
            The smooth quantization configuration.
        weight_quantizer (`Quantizer` or `None`, *optional*, defaults to `None`):
            The quantizer for weights.
        input_quantizer (`Quantizer` or `None`, *optional*, defaults to `None`):
            The quantizer for inputs.
        weights (`list[nn.Parameter]` or `None`, *optional*, defaults to `None`):
            The weights of the modules. If `None`, the weights of the modules will be used.
        inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the input activations.
        eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the inputs corresponding to the `eval_module`.
        eval_module (`nn.Module` or `None`, *optional*, defaults to `None`):
            The module to evaluate the quantization error.
        eval_kwargs (`dict[str, tp.Any]` or `None`, *optional*, defaults to `None`):
            The keyword arguments for evaluation.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        num_head_repeats (`int`, *optional*, defaults to `1`):
            The number of head repeats.
        splits (`list[int]` or `None`, *optional*, defaults to `None`):
            Passed through to the calibrator's `calibrate`
            (presumably partitioning the modules during calibration —
            verify against `SmoothCalibrator.calibrate`).
        extra_modules (`list[nn.Linear]` or `None`, *optional*, defaults to `None`):
            Extra modules to smooth.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The development data type.

    Returns:
        `torch.Tensor`:
            The smooth quantization scale in CPU.
    """
    # Normalize to a plain list: the original `modules + extra_modules` below
    # would raise a TypeError when callers pass `modules` as a tuple, since
    # `extra_modules` is always a list.
    modules = list(modules) if isinstance(modules, (list, tuple)) else [modules]
    extra_modules = [] if extra_modules is None else extra_modules
    if scale is None:
        assert inputs is not None or eval_inputs is not None, "inputs or eval_inputs must be provided"
        scale = SmoothLinearCalibrator(
            config=config,
            weight_quantizer=weight_quantizer,
            input_quantizer=input_quantizer,
            num_heads=num_heads,
            num_head_repeats=num_head_repeats,
            develop_dtype=develop_dtype,
        ).calibrate(
            x_wgts=[module.weight for module in modules] if weights is None else weights,
            x_acts=inputs,
            x_mods=modules,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
            splits=splits,
        )
        # Release calibration temporaries before mutating the weights.
        gc.collect()
        torch.cuda.empty_cache()
    upscale = scale
    # Multiply the scale into the input channels of the downstream modules.
    for module in modules + extra_modules:
        upscale = upscale.to(device=module.weight.device)
        smooth_upscale_param(module.weight, upscale, channels_dim=1)
    if prevs is not None:
        # Divide the (head-deduplicated) scale out of the upstream modules'
        # output channels so the composed computation is unchanged.
        downscale = convert_smooth_upscale_to_downscale(upscale, num_heads=num_heads, num_head_repeats=num_head_repeats)
        if isinstance(prevs, nn.Module):
            prevs = [prevs]
        for module in prevs:
            if module is None:
                continue
            downscale = downscale.to(device=module.weight.device)
            smooth_downscale_param(module.weight, downscale, channels_dim=0)
            if hasattr(module, "bias") and module.bias is not None:
                smooth_downscale_param(module.bias, downscale, channels_dim=0)
    return scale.to(device="cpu")
@torch.inference_mode()
def smooth_attention(
    *,
    q_proj: nn.Linear,
    k_proj: nn.Linear,
    scale: torch.Tensor | None,
    config: SmoothCalibConfig | None = None,
    query_quantizer: Quantizer | None = None,
    key_quantizer: Quantizer | None = None,
    queries: TensorsCache | None = None,
    keys: TensorsCache | None = None,
    attn_q: nn.Module | None = None,
    attn_k: nn.Module | None = None,
    eval_inputs: TensorsCache | None = None,
    eval_module: nn.Module = None,
    eval_kwargs: dict[str, tp.Any] = None,
    num_heads: int = 1,
    num_head_repeats: int = 1,
    with_rope: bool = True,
    develop_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Smooth attention queries and keys.

    Calibrates (if needed) a smoothing scale, multiplies it into the query
    projection's output channels and divides it out of the key projection's
    output channels, leaving the attention computation numerically unchanged.

    Args:
        q_proj (`nn.Linear`):
            The query projection module.
        k_proj (`nn.Linear`):
            The key projection module.
        scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The smooth quantization scale; calibrated when `None`.
        config (`SmoothCalibConfig` or `None`, *optional*, defaults to `None`):
            The smooth quantization configuration.
        query_quantizer (`Quantizer` or `None`, *optional*, defaults to `None`):
            The quantizer for queries.
        key_quantizer (`Quantizer` or `None`, *optional*, defaults to `None`):
            The quantizer for keys.
        queries (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the queries.
        keys (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the keys.
        attn_q (`nn.Module` or `None`, *optional*, defaults to `None`):
            The module that generates the queries.
        attn_k (`nn.Module` or `None`, *optional*, defaults to `None`):
            The module that generates the keys.
        eval_inputs (`TensorsCache` or `None`, *optional*, defaults to `None`):
            The cache of the inputs corresponding to the evaluation module.
        eval_module (`nn.Module`, *optional*, defaults to `None`):
            The module to evaluate the quantization error.
        eval_kwargs (`dict[str, tp.Any]`, *optional*, defaults to `None`):
            The keyword arguments for evaluation.
        num_heads (`int`, *optional*, defaults to `1`):
            The number of heads.
        num_head_repeats (`int`, *optional*, defaults to `1`):
            The number of head repeats.
        with_rope (`bool`, *optional*, defaults to `True`):
            Whether quantization is applied after rotary position embedding.
        develop_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The development data type.

    Returns:
        `torch.Tensor`:
            The smooth quantization scale in CPU.
    """
    if scale is None:
        assert queries is not None and keys is not None and eval_inputs is not None
        assert attn_q is not None and attn_k is not None, "modules must be provided"
        calibrator = SmoothAttentionCalibrator(
            config=config,
            query_quantizer=query_quantizer,
            key_quantizer=key_quantizer,
            num_heads=num_heads,
            num_head_repeats=num_head_repeats,
            with_rope=with_rope,
            develop_dtype=develop_dtype,
        )
        scale = calibrator.calibrate(
            q_proj_weight=q_proj.weight,
            k_proj_weight=k_proj.weight,
            queries=queries,
            keys=keys,
            query_module=attn_q,
            key_module=attn_k,
            eval_inputs=eval_inputs,
            eval_module=eval_module,
            eval_kwargs=eval_kwargs,
        )
        # Release calibration temporaries before mutating the weights.
        gc.collect()
        torch.cuda.empty_cache()
    # Scale up the query projection along its output channels...
    q_scale = scale.to(device=q_proj.weight.device)
    smooth_upscale_param(q_proj.weight, q_scale, channels_dim=0)
    # ...and divide the matching (head-deduplicated) scale out of the keys.
    k_scale = convert_smooth_upscale_to_downscale(q_scale, num_heads=num_heads, num_head_repeats=num_head_repeats)
    smooth_downscale_param(k_proj.weight, k_scale, channels_dim=0)
    return scale.to(device="cpu")
================================================
FILE: deepcompressor/csrc/load.py
================================================
# -*- coding: utf-8 -*-
"""Deepcompressor Extension.

JIT-compiles and loads the C++/CUDA extension (``pybind.cpp`` plus the
quantization kernels) via PyTorch's ``torch.utils.cpp_extension.load``.
"""
import os
from torch.utils.cpp_extension import load
__all__ = ["_C"]
# Sources live next to this file.
dirpath = os.path.dirname(__file__)
# NOTE(review): "-lgomp" is a *linker* flag; placing it in `extra_cflags`
# likely has no effect during compilation — confirm whether it belongs in
# `extra_ldflags` instead.
_C = load(
    name="deepcompressor_C",
    sources=[f"{dirpath}/pybind.cpp", f"{dirpath}/quantize/quantize.cu"],
    extra_cflags=["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++20"],
    extra_cuda_cflags=[
        "-O3",
        "-std=c++20",
        # Expose half/bfloat16 operators and conversions in CUDA code.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        "--ptxas-options=--allow-expensive-optimizations=true",
        "--threads=8",
    ],
)
================================================
FILE: deepcompressor/csrc/pybind.cpp
================================================
// NOTE(review): the include targets below were reconstructed — the extraction
// stripped all `<...>` contents from this file; verify against upstream.
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include "quantize/quantize.h"

// Python bindings for the CUDA quantization kernels.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("round_to_nearest_in_codebook_cuda", &round_to_nearest_in_codebook_cuda,
        py::arg("tensor"), py::arg("codebook"), py::arg("inplace") = false,
        py::arg("bnb") = false, "RTN with codebook (CUDA)");
}
================================================
FILE: deepcompressor/csrc/quantize/quantize.cu
================================================
// NOTE(review): the include targets below were reconstructed — the extraction
// stripped all `<...>` contents from this file; verify against upstream.
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>
#include "quantize.h"
// The following code is adapted from the bitsandbytes library:
// https://github.com/bitsandbytes-foundation/bitsandbytes/blob/main/csrc/kernels.cu#L232
template
__device__ __forceinline__
typename std::conditional::type
bnb_nearest_neighbor(float_t x, float_t *codebook, const int C)
{
int mid = (C >> 1) - 1;
int hi = C - 1;
int lo = 0;
float_t lval = codebook[lo];
float_t hval = codebook[hi];
float_t mval = codebook[mid];
for (int step = (C >> 2); step > 0; step >>= 1)
{
if (x > mval)
{
lo = mid;
lval = mval;
mid += step;
}
else
{
hi = mid;
hval = mval;
mid -= step;
}
mval = codebook[mid];
}
if (x > mval)
{
if constexpr (ret_val)
{
return (x - mval > hval - x) ? hval : mval;
}
else
{
return (x - mval > hval - x) ? hi : mid;
}
}
else
{
if constexpr (ret_val)
{
return (x - lval < mval - x) ? lval : mval;
}
else
{
return (x - lval < mval - x) ? lo : mid;
}
}
}
template
__device__ __forceinline__
typename std::conditional::type
nearest_neighbor(float_t x, const float_t *codebook, int C)
{
int lo = 0;
int bit = 1 << (31 - __clz(C));
float_t lval = codebook[lo];
while (bit)
{
int next = lo | bit;
float_t nval = codebook[next];
bool pred = next < C && nval <= x;
lo = pred ? next : lo;
lval = pred ? nval : lval;
bit >>= 1;
}
int hi = lo + (lo < C - 1);
float_t hval = codebook[hi];
if constexpr (ret_val)
{
return (x + x < lval + hval) ? lval : hval;
}
else
{
return (x + x < lval + hval) ? lo : hi;
}
}
// CUDA kernel: Each thread processes one element from x and finds the nearest
// codebook entry. The codebook (of size C < 256) is first loaded into shared
// memory.
template
__global__ void round_to_nearest_in_codebook_kernel(
const float_t *__restrict__ x, const float_t *__restrict__ codebook,
float_t *__restrict__ y, const int N, const int C)
{
// Use a shared memory array for the codebook.
__shared__ float_t s_codebook[256];
// Have the first few threads load the codebook into shared memory.
for (int i = threadIdx.x; i < C; i += blockDim.x)
{
s_codebook[i] = codebook[i];
}
__syncthreads();
// Global index for the element processed by this thread.
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N)
{
if constexpr (bnb)
{
y[idx] = bnb_nearest_neighbor(x[idx], s_codebook, C);
}
else
{
y[idx] = nearest_neighbor(x[idx], s_codebook, C);
}
}
}
torch::Tensor round_to_nearest_in_codebook_cuda(torch::Tensor tensor,
torch::Tensor codebook,
bool inplace, bool bnb)
{
auto x = tensor.contiguous();
auto c = codebook.contiguous();
auto y = inplace ? x : torch::empty_like(tensor);
const int N = x.numel();
const int C = c.numel();
const int threads = 256;
const int blocks = (N + threads - 1) / threads;
AT_DISPATCH_FLOATING_TYPES(
tensor.scalar_type(), "round_to_nearest_in_codebook_cuda", [&]
{
if (bnb && (C & (C - 1)) == 0) {
round_to_nearest_in_codebook_kernel
<<>>(x.data_ptr(),
c.data_ptr(),
y.data_ptr(), N, C);
} else {
round_to_nearest_in_codebook_kernel