Full Code of mit-han-lab/deepcompressor for AI

Repository: mit-han-lab/deepcompressor
Branch: main
Commit: 69f3473f5e1c
Files: 229
Total size: 3.7 MB

Directory structure:
gitextract_0u4zluxv/

├── .gitignore
├── LICENSE
├── README.md
├── assets/
│   ├── diffusion/
│   │   └── .gitkeep
│   └── llm/
│       └── .gitkeep
├── deepcompressor/
│   ├── __init__.py
│   ├── app/
│   │   ├── __init__.py
│   │   ├── diffusion/
│   │   │   ├── __init__.py
│   │   │   ├── cache/
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── config.py
│   │   │   ├── dataset/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base.py
│   │   │   │   ├── calib.py
│   │   │   │   ├── collect/
│   │   │   │   │   ├── calib.py
│   │   │   │   │   └── utils.py
│   │   │   │   └── data/
│   │   │   │       ├── COCO/
│   │   │   │       │   ├── COCO.py
│   │   │   │       │   └── __init__.py
│   │   │   │       ├── DCI/
│   │   │   │       │   ├── DCI.py
│   │   │   │       │   └── __init__.py
│   │   │   │       ├── MJHQ/
│   │   │   │       │   ├── MJHQ.py
│   │   │   │       │   └── __init__.py
│   │   │   │       ├── __init__.py
│   │   │   │       └── dump.py
│   │   │   ├── eval/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── config.py
│   │   │   │   └── metrics/
│   │   │   │       ├── __init__.py
│   │   │   │       ├── fid.py
│   │   │   │       ├── image_reward.py
│   │   │   │       ├── multimodal.py
│   │   │   │       ├── run.py
│   │   │   │       └── similarity.py
│   │   │   ├── nn/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── attention.py
│   │   │   │   ├── patch.py
│   │   │   │   └── struct.py
│   │   │   ├── pipeline/
│   │   │   │   ├── __init__.py
│   │   │   │   └── config.py
│   │   │   ├── ptq.py
│   │   │   ├── quant/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── activation.py
│   │   │   │   ├── config.py
│   │   │   │   ├── quantizer/
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── config.py
│   │   │   │   │   └── quantizer.py
│   │   │   │   ├── rotate.py
│   │   │   │   ├── smooth.py
│   │   │   │   ├── utils.py
│   │   │   │   └── weight.py
│   │   │   └── utils.py
│   │   └── llm/
│   │       ├── __init__.py
│   │       ├── cache/
│   │       │   ├── __init__.py
│   │       │   └── config.py
│   │       ├── config.py
│   │       ├── eval/
│   │       │   ├── __init__.py
│   │       │   ├── base.py
│   │       │   ├── config.py
│   │       │   ├── custom.py
│   │       │   ├── lm_eval.py
│   │       │   └── longbench/
│   │       │       ├── __init__.py
│   │       │       ├── eval.py
│   │       │       ├── metrics.py
│   │       │       └── task2prompt.json
│   │       ├── model/
│   │       │   ├── __init__.py
│   │       │   └── config.py
│   │       ├── nn/
│   │       │   ├── __init__.py
│   │       │   ├── patch.py
│   │       │   └── struct.py
│   │       ├── ptq.py
│   │       └── quant/
│   │           ├── __init__.py
│   │           ├── activation.py
│   │           ├── config.py
│   │           ├── dataset.py
│   │           ├── quantizer/
│   │           │   ├── __init__.py
│   │           │   ├── config.py
│   │           │   └── quantizer.py
│   │           ├── reorder.py
│   │           ├── rotate.py
│   │           ├── smooth.py
│   │           ├── utils.py
│   │           └── weight.py
│   ├── backend/
│   │   ├── __init__.py
│   │   ├── nunchaku/
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── convert_lora.py
│   │   │   └── utils.py
│   │   ├── qserve/
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   └── utils.py
│   │   ├── tinychat/
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── csrc/
│   │   │   │   ├── load.py
│   │   │   │   ├── pybind.cpp
│   │   │   │   ├── quantization/
│   │   │   │   │   ├── dequantize.cuh
│   │   │   │   │   ├── gemm/
│   │   │   │   │   │   ├── gemm_cuda.cu
│   │   │   │   │   │   ├── gemm_cuda.h
│   │   │   │   │   │   └── semaphore.h
│   │   │   │   │   └── gemv/
│   │   │   │   │       ├── gemv_cuda.cu
│   │   │   │   │       └── gemv_cuda.h
│   │   │   │   └── utils.cuh
│   │   │   ├── linear.py
│   │   │   └── utils.py
│   │   └── utils.py
│   ├── calib/
│   │   ├── __init__.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   ├── lowrank.py
│   │   │   ├── range.py
│   │   │   ├── reorder.py
│   │   │   ├── rotation.py
│   │   │   ├── search.py
│   │   │   └── smooth.py
│   │   ├── lowrank.py
│   │   ├── metric.py
│   │   ├── range.py
│   │   ├── reorder.py
│   │   ├── rotate.py
│   │   ├── search.py
│   │   └── smooth.py
│   ├── csrc/
│   │   ├── load.py
│   │   ├── pybind.cpp
│   │   └── quantize/
│   │       ├── quantize.cu
│   │       └── quantize.h
│   ├── data/
│   │   ├── __init__.py
│   │   ├── cache.py
│   │   ├── codebook.py
│   │   ├── common.py
│   │   ├── dtype.py
│   │   ├── range.py
│   │   ├── scale.py
│   │   ├── tensor.py
│   │   ├── utils/
│   │   │   ├── __init__.py
│   │   │   ├── dtype.py
│   │   │   ├── reshape.py
│   │   │   ├── scale.py
│   │   │   └── shape.py
│   │   └── zero.py
│   ├── dataset/
│   │   ├── __init__.py
│   │   ├── action.py
│   │   ├── cache.py
│   │   └── config.py
│   ├── nn/
│   │   ├── __init__.py
│   │   ├── patch/
│   │   │   ├── __init__.py
│   │   │   ├── conv.py
│   │   │   ├── linear.py
│   │   │   ├── lowrank.py
│   │   │   └── sdpa.py
│   │   └── struct/
│   │       ├── __init__.py
│   │       ├── attn.py
│   │       └── base.py
│   ├── quantizer/
│   │   ├── __init__.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── kernel.py
│   │   │   └── lowrank.py
│   │   ├── impl/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── info.py
│   │   │   ├── scale.py
│   │   │   ├── simple.py
│   │   │   └── ste.py
│   │   ├── kernel/
│   │   │   ├── __init__.py
│   │   │   ├── gptq.py
│   │   │   └── rtn.py
│   │   └── processor.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── config/
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── model.py
│   │   │   ├── output.py
│   │   │   └── path.py
│   │   ├── dataclass.py
│   │   ├── hooks/
│   │   │   ├── __init__.py
│   │   │   ├── branch.py
│   │   │   ├── hook.py
│   │   │   ├── packager.py
│   │   │   └── processor.py
│   │   ├── math/
│   │   │   ├── __init__.py
│   │   │   ├── functional.py
│   │   │   └── hadamard.py
│   │   ├── patch.py
│   │   └── tools/
│   │       ├── __init__.py
│   │       ├── logging.py
│   │       └── sys.py
│   └── version.py
├── environment.yml
├── examples/
│   ├── diffusion/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── configs/
│   │   │   ├── __default__.yaml
│   │   │   ├── collect/
│   │   │   │   └── qdiff.yaml
│   │   │   ├── lora/
│   │   │   │   ├── __default__.yaml
│   │   │   │   └── flux.1-dev/
│   │   │   │       ├── anime.yaml
│   │   │   │       ├── ghibsky.yaml
│   │   │   │       ├── realism.yaml
│   │   │   │       ├── sketch.yaml
│   │   │   │       └── yarn.yaml
│   │   │   ├── model/
│   │   │   │   ├── flux.1-dev.yaml
│   │   │   │   ├── flux.1-schnell.yaml
│   │   │   │   ├── pixart-sigma.yaml
│   │   │   │   └── sana-1.6b.yaml
│   │   │   ├── svdquant/
│   │   │   │   ├── __default__.yaml
│   │   │   │   ├── fast.yaml
│   │   │   │   ├── gptq.yaml
│   │   │   │   ├── int4.yaml
│   │   │   │   └── nvfp4.yaml
│   │   │   └── text/
│   │   │       ├── __default__.yaml
│   │   │       └── awq.yaml
│   │   ├── prompts/
│   │   │   ├── lora/
│   │   │   │   ├── anime.yaml
│   │   │   │   ├── ghibsky.yaml
│   │   │   │   ├── realism.yaml
│   │   │   │   ├── sketch.yaml
│   │   │   │   └── yarn.yaml
│   │   │   └── qdiff.yaml
│   │   └── scripts/
│   │       └── svdquant.sh
│   └── llm/
│       ├── .gitignore
│       ├── README.md
│       ├── configs/
│       │   ├── __default__.yaml
│       │   ├── awq.yaml
│       │   ├── gptq.yaml
│       │   ├── ooo.yaml
│       │   ├── qoq-g128.yaml
│       │   ├── qoq-gchn.yaml
│       │   ├── smoothquant-dynamic.yaml
│       │   └── smoothquant-static.yaml
│       └── scripts/
│           ├── awq.sh
│           ├── gptq.sh
│           ├── qoq.sh
│           └── smoothquant.sh
└── pyproject.toml

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VS Code
.vscode/
!.vscode/settings.json

.DS_Store
*.log
*.pt
.tmp/
runs
exps
runs/
exps/
wandb
wandb/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [2024] Yujun Lin, Muyang Li, Zhekai Zhang, Haotian Tang, Shang Yang, Song Han

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
<p align="center">
<img src="assets/deepcompressor.png" alt="DeepCompressor Logo" width="450">
</p>

<h2><p align="center">Model Compression Toolbox for Large Language Models and Diffusion Models</p></h2>

<p align="center">
    <a href="https://github.com/mit-han-lab/deepcompressor/blob/master/LICENSE">
        <img alt="Apache License" src="https://img.shields.io/github/license/mit-han-lab/deepcompressor">
    </a>
    <!-- <a href="https://deepcompressor.mit.edu">
        <img alt="Website" src="https://img.shields.io/website?up_message=deepcompressor&url=https%3A%2F%2Fdeepcompressor.mit.edu">
    </a> -->
   <!-- <a href="https://pypi.org/project/deepcompressor/">
        <img alt="Pypi" src="https://img.shields.io/pypi/v/deepcompressor">
    </a> -->
</p>

## News
- [2025/02] 🎉 [**QServe**](https://arxiv.org/abs/2405.04532) has been accepted to MLSys 2025!
- [2025/01] 🎉 [**SVDQuant**](https://arxiv.org/abs/2411.05007) has been accepted to ICLR 2025 (Spotlight)!
- [2024/12] 🎉 [**QServe**](https://github.com/mit-han-lab/qserve) has been integrated into NVIDIA [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama)!
- [2024/11] 🔥 Our latest **W4A4** diffusion model quantization work, the [**SVDQuant**](https://arxiv.org/abs/2411.05007) algorithm and the [**Nunchaku**](https://github.com/mit-han-lab/nunchaku) system, is publicly released! Check our [paper](http://arxiv.org/abs/2411.05007)!
- [2024/05] 🔥 Our latest **W4A8KV4** LLM quantization work, the **QoQ** algorithm and the **QServe** system, is publicly released! **QoQ** is short for *quattuor-octō-quattuor*, which is 4-8-4 in Latin. Check our [paper](https://arxiv.org/abs/2405.04532)!

## Key Features

***DeepCompressor*** is an open-source model compression toolbox for large language models and diffusion models based on PyTorch. DeepCompressor currently supports fake quantization with any integer or floating-point data type of at most 8 bits, e.g., INT8, INT4, and FP4_E2M1. The examples below implement the following algorithms.

+ [Post-training quantization for large language models](/examples/llm/):
  + Weight-only Quantization
    + [AWQ (W4A16)](/examples/llm/configs/awq.yaml)
    + [GPTQ (W4A16)](/examples/llm/configs/gptq.yaml)
  + Weight-Activation Quantization
    + [SmoothQuant (W8A8)](/examples/llm/configs/smoothquant-static.yaml)
  + Weight-Activation and KV-Cache Quantization
    + [QoQ (W4A8KV4)](/examples/llm/)
+ [Post-training quantization for diffusion models](/examples/diffusion/):
  + Weight-Activation Quantization
    + [SVDQuant (W4A4)](/examples/diffusion/)

DeepCompressor also contains examples that integrate with other inference libraries:
+ [Deploy weight-only quantized LLMs with TinyChat](/examples/llm/)
+ [Deploy quantized LLMs with QServe](/examples/llm/)
+ [Deploy quantized diffusion models with Nunchaku](/examples/diffusion/)

## Installation

### Install from Source

1. Clone this repository and navigate to the deepcompressor folder
```
git clone https://github.com/mit-han-lab/deepcompressor
cd deepcompressor
```

2. Install Package
```
conda env create -f environment.yml
poetry install
```

## Highlights
### SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models

[[Website](https://hanlab.mit.edu/projects/svdquant)][[Paper](http://arxiv.org/abs/2411.05007)][[Nunchaku Inference System](https://github.com/mit-han-lab/nunchaku)]

Diffusion models have been proven highly effective at generating high-quality images. However, as these models grow larger, they require significantly more memory and suffer from higher latency, posing substantial challenges for deployment. In this work, we aim to accelerate diffusion models by quantizing their weights and activations to 4 bits. At such an aggressive level, both weights and activations are highly sensitive, where conventional post-training quantization methods for large language models like smoothing become insufficient. To overcome this limitation, we propose **SVDQuant**, a new 4-bit quantization paradigm. Different from smoothing, which redistributes outliers between weights and activations, our approach absorbs these outliers using a low-rank branch. We first consolidate the outliers by shifting them from activations to weights, then employ a high-precision low-rank branch to take in the weight outliers with Singular Value Decomposition (SVD). This process eases the quantization on both sides. However, naïvely running the low-rank branch independently incurs significant overhead due to extra data movement of activations, negating the quantization speedup. To address this, we co-design an inference engine **Nunchaku** that fuses the kernels of the low-rank branch into those of the low-bit branch to cut off redundant memory access. It can also seamlessly support off-the-shelf low-rank adapters (LoRAs) without the need for re-quantization. Extensive experiments on SDXL, PixArt-Σ, and FLUX.1 validate the effectiveness of SVDQuant in preserving image quality. We reduce the memory usage for the 12B FLUX.1 models by 3.5×, achieving 3.0× speedup over the 4-bit weight-only quantized baseline on the 16GB laptop 4090 GPU, paving the way for more interactive applications on PCs.

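As a rough illustration of the decomposition described above (a hedged sketch, not the paper's implementation): the weight matrix is split as W ≈ L + Q(W − L), where the low-rank branch L stays in high precision and only the residual is quantized to 4 bits. The toy code below substitutes a rank-1 approximation found by power iteration for a truncated SVD; all names are hypothetical.

```python
# Toy SVDQuant-style reconstruction in pure Python: low-rank branch + 4-bit residual.

def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def rank1(w, iters=50):
    """Dominant rank-1 factor u v^T of w via power iteration on w^T w."""
    wt = [list(col) for col in zip(*w)]            # transpose
    v = [1.0] * len(w[0])
    for _ in range(iters):
        v = matvec(wt, matvec(w, v))
        norm = sum(x * x for x in v) ** 0.5
        v = [x / norm for x in v]
    u = matvec(w, v)                               # u carries the singular value
    return u, v

def fake_quantize(values, num_bits=4):
    """Symmetric INT4 quantize-dequantize on a flat list."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = (max(abs(x) for x in values) / qmax) or 1.0
    return [min(max(round(x / scale), -qmax - 1), qmax) * scale for x in values]

def svdquant_reconstruct(w):
    """Reconstruct w as the low-rank branch plus a quantized residual."""
    u, v = rank1(w)
    low = [[ui * vj for vj in v] for ui in u]      # L = u v^T
    flat_res = [wij - lij for rw, rl in zip(w, low) for wij, lij in zip(rw, rl)]
    q = fake_quantize(flat_res)                    # only the residual is 4-bit
    n = len(w[0])
    return [[low[i][j] + q[i * n + j] for j in range(n)] for i in range(len(w))]

# A weight matrix with one outlier: the low-rank branch absorbs it, so the
# 4-bit residual reconstructs far more accurately than quantizing w directly.
w = [[0.1, 8.0], [0.2, 0.1]]
w_hat = svdquant_reconstruct(w)
```

On this toy matrix, quantizing `w` directly forces the scale to track the outlier and wipes out the small entries, while the rank-1 branch absorbs the outlier and leaves an easy-to-quantize residual.
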
![Teaser](/assets/diffusion/svdquant/teaser.jpg)
![SVDQuant](/assets/diffusion/svdquant/svdquant.gif)

#### Quality Evaluation

Below are the quality and similarity metrics evaluated on 5,000 samples from the MJHQ-30K dataset. IR denotes ImageReward. Our 4-bit results outperform other 4-bit baselines, effectively preserving the visual quality of 16-bit models.

| Model                      | Precision |  Method   | FID ($\downarrow$) | IR ($\uparrow$) | LPIPS ($\downarrow$) | PSNR ($\uparrow$) |
|----------------------------|-----------|-----------|--------------------|-----------------|----------------------|-------------------|
| FLUX.1-dev (50 Steps)      | BF16      | --        | 20.3               | 0.953           | --                   | --                |
|                            | W4A16     | NF4       | 20.6               | 0.910           | 0.272                | 19.5              |
|                            | INT W4A4  |           | 20.2               | 0.908           | 0.322                | 18.5              |
|                            | INT W4A4  | SVDQuant  | 19.9               | 0.935           | 0.223                | 21.0              |
|                            | NVFP4     |           | 20.3               | 0.961           | 0.345                | 16.3              |
|                            | NVFP4     | SVDQuant  | 20.3               | 0.945           | 0.205                | 21.5              |
| FLUX.1-schnell (4 Steps)   | BF16      | --        | 19.2               | 0.938           | --                   | --                |
|                            | W4A16     | NF4       | 18.9               | 0.943           | 0.257                | 18.2              |
|                            | INT W4A4  |           | 18.1               | 0.962           | 0.345                | 16.3              |
|                            | INT W4A4  | SVDQuant  | 18.3               | 0.951           | 0.257                | 18.3              |
|                            | NVFP4     |           | 19.0               | 0.952           | 0.276                | 17.6              |
|                            | NVFP4     | SVDQuant  | 18.9               | 0.966           | 0.228                | 19.0              |
| SANA-1.6b (20 Steps)       | BF16      | --        | 20.6               | 0.952           | --                   | --                |
|                            | INT W4A4  |           | 20.5               | 0.894           | 0.339                | 15.3              |
|                            | INT W4A4  | SVDQuant  | 19.3               | 0.935           | 0.220                | 17.8              |
|                            | NVFP4     |           | 19.7               | 0.929           | 0.236                | 17.4              |
|                            | NVFP4     | SVDQuant  | 20.2               | 0.941           | 0.176                | 19.0              |
| PixArt-Sigma (20 Steps)    | FP16      | --        | 16.6               | 0.944           | --                   | --                |
|                            | INT W4A8  | ViDiT-Q   | 37.3               | 0.573           | 0.611                | 12.0              |
|                            | INT W4A4  | SVDQuant  | 19.2               | 0.878           | 0.323                | 17.6              |
|                            | NVFP4     |           | 31.8               | 0.660           | 0.517                | 14.8              |
|                            | NVFP4     | SVDQuant  | 16.6               | 0.940           | 0.271                | 18.5              |

### QServe: W4A8KV4 Quantization for Efficient LLM Serving

[[Website](https://hanlab.mit.edu/projects/qserve)][[Paper](https://arxiv.org/abs/2405.04532)][[QoQ Algorithm Code](/examples/llm)][[QServe GPU System](https://github.com/mit-han-lab/qserve)]

Quantization can accelerate large language model (LLM) inference. Going beyond INT8 quantization, the research community is actively exploring even lower precision, such as INT4. Nonetheless, state-of-the-art INT4 quantization techniques only accelerate low-batch, edge LLM inference, failing to deliver performance gains in large-batch, cloud-based LLM serving. We uncover a critical issue: existing INT4 quantization methods suffer from significant runtime overhead (20-90%) when **dequantizing either weights or partial sums** on GPUs. To address this challenge, we introduce **QoQ**, a W4A8KV4 quantization algorithm with 4-bit weights, 8-bit activations, and a 4-bit KV cache. QoQ stands for **quattuor-octo-quattuor**, which is 4-8-4 in Latin. QoQ is implemented in the **QServe** inference library, which achieves measured speedups. The key insight driving QServe is that the efficiency of LLM serving on GPUs is critically influenced by **operations on low-throughput CUDA cores**. Building upon this insight, the QoQ algorithm introduces progressive quantization, which keeps the dequantization overhead in W4A8 GEMM low. Additionally, we develop SmoothAttention to effectively mitigate the accuracy degradation incurred by 4-bit KV quantization. In the QServe system, we perform compute-aware weight reordering and take advantage of register-level parallelism to reduce dequantization latency. We also make fused attention memory-bound, harnessing the performance gain brought by KV4 quantization. As a result, QServe improves the maximum achievable serving throughput of Llama-3-8B by **1.2×** on A100 and **1.4×** on L40S, and of Qwen1.5-72B by **2.4×** on A100 and **3.5×** on L40S, compared to TensorRT-LLM.

![QoQ-QServe](/assets/llm/qoq/qoq-qserve.png)
![QoQ](/assets/llm/qoq/qoq.png)
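To make the progressive-quantization idea concrete, here is a minimal two-level NumPy sketch (our own simplified illustration, not the QoQ kernels or this repository's code): weights are first quantized to INT8 with floating-point per-channel scales, then the INT8 values are re-quantized to INT4 with *integer* per-group scales, so dequantizing the 4-bit weights at runtime only needs cheap integer arithmetic.

```python
import numpy as np

def progressive_quantize(w: np.ndarray, group_size: int = 64):
    """Two-level (progressive) quantization sketch: FP32 -> INT8 -> INT4."""
    # Level 1: symmetric per-output-channel INT8 quantization (FP scale).
    s8 = np.abs(w).max(axis=1, keepdims=True) / 127.0
    s8 = np.where(s8 == 0, 1.0, s8)
    q8 = np.clip(np.rint(w / s8), -127, 127)
    # Level 2: symmetric per-group INT4 re-quantization of the INT8 values,
    # with *integer* group scales so dequantization stays in integer math.
    g = q8.reshape(w.shape[0], -1, group_size)
    s4 = np.ceil(np.abs(g).max(axis=2, keepdims=True) / 7.0)
    s4 = np.where(s4 == 0, 1.0, s4)
    q4 = np.clip(np.rint(g / s4), -7, 7)
    return q4, s4, s8

def dequantize(q4, s4, s8):
    # INT4 -> INT8 via an integer multiply, then one FP scale per channel.
    q8 = (q4 * s4).reshape(q4.shape[0], -1)
    return q8 * s8

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 128)).astype(np.float32)
q4, s4, s8 = progressive_quantize(w)
w_hat = dequantize(q4, s4, s8)
```

The real QoQ scheme additionally uses asymmetric 4-bit groups, SmoothAttention for the KV cache, and fused CUDA kernels; the sketch only shows why integer second-level scales keep W4A8 GEMM dequantization cheap.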


#### Perplexity Evaluation

Below is the WikiText2 perplexity evaluated with a sequence length of 2048. Lower is better.

|   Methods   |  Precision   | Llama-3.1 70B | Llama-3.1 8B | Llama-3 70B |  Llama-3 8B | Llama-2 7B | Llama-2 13B | Llama-2 70B | Llama 7B | Llama 13B | Llama 30B | Mistral 7B | Yi 34B |
|-------------|--------------|---------------|--------------|-------------| ------------|------------|-------------|-------------|----------|-----------|-----------|------------|--------|
| FP16        |              | 2.81          | 6.24         | 2.85        |  6.14       | 5.47       | 4.88        | 3.32        | 5.68     | 5.09      | 4.10      | 5.25       | 4.60   |
| SmoothQuant | W8A8         | 3.23          | 6.38         | 3.14        |  6.28       | 5.54       | 4.95        | 3.36        | 5.73     | 5.13      | 4.23      | 5.29       | 4.69   |
| GPTQ-R      | W4A16 g128   | 3.46          | 6.64         | 3.42        |  6.56       | 5.63       | 4.99        | 3.43        | 5.83     | 5.20      | 4.22      | 5.39       | 4.68   |
| AWQ         | W4A16 g128   | 3.22          | 6.60         | 3.20        |  6.54       | 5.60       | 4.97        | 3.41        | 5.78     | 5.19      | 4.21      | 5.37       | 4.67   |
| QuaRot      | W4A4         | 5.97          | 8.32         | 6.75        |  8.33       | 6.19       | 5.45        | 3.83        | 6.34     | 5.58      | 4.64      | 5.77       | -      |
| SpinQuant   | W4A4         | 4.80          | 7.42         | 6.27        |  7.37       | 5.96       | 5.24        | 3.71        | 6.14     | 5.39      | 4.56      | -          | -      |
| Atom        | W4A4 g128    | -             | -            | 4.33        |  7.78       | 6.12       | 5.31        | 3.73        | 6.25     | 5.52      | 4.61      | 5.76       | 4.97   |
| QoQ         | W4A8KV4      | 3.68          | 6.87         | 3.65        |  6.81       | 5.75       | 5.11        | 3.50        | 5.92     | 5.27      | 4.31      | 5.44       | 4.73   |
| QoQ         | W4A8KV4 g128 | 3.51          | 6.77         | 3.50        |  6.70       | 5.67       | 5.06        | 3.46        | 5.88     | 5.23      | 4.27      | 5.41       | 4.73   |

\* SmoothQuant is evaluated with per-tensor static KV cache quantization.
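For reference, perplexity numbers like those above typically follow a chunked-evaluation recipe; the sketch below is a generic illustration of that recipe (our assumption, not this repository's evaluation code):

```python
import math

def chunked_perplexity(token_nlls: list[float], seq_len: int = 2048) -> float:
    """Split per-token negative log-likelihoods into non-overlapping windows
    of `seq_len` tokens and exponentiate the mean NLL."""
    total_nll, total_tokens = 0.0, 0
    for start in range(0, len(token_nlls) - seq_len + 1, seq_len):
        total_nll += sum(token_nlls[start : start + seq_len])
        total_tokens += seq_len
    return math.exp(total_nll / total_tokens)

# Sanity check: a model assigning probability 1/2 to every token has PPL 2.
ppl = chunked_perplexity([math.log(2.0)] * 4096)
```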

#### Efficiency Benchmarks

When serving the large language models Llama-3-8B and Qwen1.5-72B on L40S and A100 GPUs, QServe demonstrates superior performance, achieving **1.2x-1.4x higher throughput** on Llama-3-8B and **2.4x-3.5x higher throughput** on Qwen1.5-72B compared to the leading industry solution, TensorRT-LLM.

See the [QServe GPU Inference System](https://github.com/mit-han-lab/qserve) repository for more details on the benchmark settings.

| L40S (48G)           | Llama-3-8B | Llama-2-7B | Mistral-7B | Llama-2-13B | Llama-30B | Yi-34B    | Llama-2-70B | Qwen-1.5-72B |
|----------------------|------------|------------|------------|-------------|-----------|-----------|-------------|--------------|
| TRT-LLM-FP16         | 1326       | 444        | 1566       | 92          | OOM       | OOM       | OOM         | OOM          |
| TRT-LLM-W4A16        | 1431       | 681        | 1457       | 368         | 148       | 313       | 119         | 17           |
| TRT-LLM-W8A8         | 2634       | 1271       | 2569       | 440         | 123       | 364       | OOM         | OOM          |
| Atom-W4A4            | --         | 2120       | --         | --          | --        | --        | --          | --           |
| QuaRot-W4A4          | --         | 805        | --         | 413         | 133       | --        | --          | 15           |
| QServe-W4A8KV4       | **3656**   | **2394**   | **3774**   | **1327**    | **504**   | **869**   | **286**     | **59**       |
| Throughput Increase* | **1.39x**  | **1.13x**  | **1.47x**  | **3.02x**   | **3.41x** | **2.39x** | **2.40x**   | **3.47x**    |

| A100 (80G)           | Llama-3-8B | Llama-2-7B | Mistral-7B | Llama-2-13B | Llama-30B | Yi-34B    | Llama-2-70B | Qwen-1.5-72B |
|----------------------|------------| -----------|------------|-------------|-----------|-----------|-------------|--------------|
| TRT-LLM-FP16         | 2503       | 1549       | 2371       | 488         | 80        | 145       | OOM         | OOM          |
| TRT-LLM-W4A16        | 2370       | 1549       | 2403       | 871         | 352       | 569       | 358         | 143          |
| TRT-LLM-W8A8         | 2396       | 2334       | 2427       | 1277        | 361       | 649       | 235         | 53           |
| Atom-W4A4            | --         | 1160       | --         | --          | --        | --        | --          | --           |
| QuaRot-W4A4          | --         | 1370       | --         | 289         | 267       | --        | --          | 68           |
| QServe-W4A8KV4       | **3005**   | **2908**   | **2970**   | **1741**    | **749**   | **803**   | **419**     | **340**      |
| Throughput Increase* | **1.20x**  | **1.25x**  | **1.22x**  | **1.36x**   | **2.07x** | **1.23x** | **1.17x**   | **2.38x**    |

The tables above report the absolute token generation throughput of QServe and the baseline systems (unit: tokens/second; `--` means unsupported). All experiments were conducted under the same device memory budget. The throughput increase of QServe is calculated with respect to the best baseline in each column.
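Since the throughput increase is a ratio against the best baseline in each column, it can be reproduced directly from the table; for example, the L40S / Llama-3-8B column:

```python
# QServe throughput divided by the best available baseline in the column
# (OOM and unsupported "--" entries are excluded).
baselines = {"TRT-LLM-FP16": 1326, "TRT-LLM-W4A16": 1431, "TRT-LLM-W8A8": 2634}
qserve = 3656
increase = qserve / max(baselines.values())
print(f"{increase:.2f}x")  # prints "1.39x", matching the table
```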

## Reference

If you find `deepcompressor` useful or relevant to your research, please kindly cite our papers:

```bibtex
@article{lin2024qserve,
  title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
  author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and Han, Song},
  journal={arXiv preprint arXiv:2405.04532},
  year={2024}
}

@article{li2024svdquant,
  title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
  author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
  journal={arXiv preprint arXiv:2411.05007},
  year={2024}
}
```

## Related Projects

The following projects are highly related to QServe. Our group has developed full-stack application-algorithm-system-hardware support for efficient large models, receiving **9k+ GitHub stars** and **over 1M Huggingface community downloads**.

You are also welcome to check out [MIT HAN Lab](https://hanlab.mit.edu) for other exciting projects on **Efficient Generative AI**!

- [**System**] [QServe: W4A8KV4 Quantization for Efficient LLM Serving](https://github.com/mit-han-lab/qserve)

- [**System**] [TinyChat: Efficient and Lightweight Chatbot with AWQ](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat)

- [**Application**] [VILA: On Pretraining of Visual-Language Models](https://github.com/Efficient-Large-Model/VILA)

- [**Algorithm**] [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://github.com/mit-han-lab/llm-awq)

- [**Algorithm**] [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://github.com/mit-han-lab/smoothquant)

- [**Algorithm**] [DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models](https://github.com/mit-han-lab/distrifuser)

- [**Hardware**] [SpAtten: Efficient Sparse Attention Architecture with Cascade Token and Head Pruning](https://arxiv.org/abs/2012.09852)


## Acknowledgments

DeepCompressor is inspired by many open-source libraries, including (but not limited to) [GPTQ](https://arxiv.org/abs/2210.17323), [QuaRot](https://arxiv.org/abs/2404.00456), and [Atom](https://arxiv.org/abs/2310.19102).

================================================
FILE: assets/diffusion/.gitkeep
================================================


================================================
FILE: assets/llm/.gitkeep
================================================


================================================
FILE: deepcompressor/__init__.py
================================================
from .version import __version__  # noqa: F401


================================================
FILE: deepcompressor/app/__init__.py
================================================


================================================
FILE: deepcompressor/app/diffusion/__init__.py
================================================


================================================
FILE: deepcompressor/app/diffusion/cache/__init__.py
================================================
from .config import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig


================================================
FILE: deepcompressor/app/diffusion/cache/config.py
================================================
# -*- coding: utf-8 -*-
"""LLM quantization cache configuration."""

import functools
import re
import typing as tp
from dataclasses import dataclass, field

from omniconfig import configclass

from deepcompressor.utils.config.path import BasePathConfig

from ..nn.struct import DiffusionModelStruct

__all__ = ["DiffusionQuantCacheConfig", "DiffusionPtqCacheConfig"]


@dataclass
class DiffusionQuantCacheConfig(BasePathConfig):
    """Denoising diffusion model quantization cache path.

    Args:
        smooth (`str`, *optional*, default=`""`):
            The smoothing scales cache path.
        branch (`str`, *optional*, default=`""`):
            The low-rank branches cache path.
        wgts (`str`, *optional*, default=`""`):
            The weight quantizers state dict cache path.
        acts (`str`, *optional*, default=`""`):
            The activation quantizers state dict cache path.
    """

    smooth: str = ""
    branch: str = ""
    wgts: str = ""
    acts: str = ""

    @staticmethod
    def simplify_path(path: str, key_map: dict[str, set[str]]) -> str:
        """Simplify the cache path."""
        to_replace = {}
        # we first extract all the parts matching the pattern "(skip|include).\[[a-zA-Z0-9_\+]+\]"
        for part in re.finditer(r"(skip|include)\.\[[a-zA-Z0-9_\+]+\]", path):
            # remove the "skip." or "include." prefix
            part = part.group(0)
            if part[0] == "s":
                prefix, keys = part[:4], part[6:-1]
            else:
                prefix, keys = part[:7], part[9:-1]
            # simplify the keys
            keys = "+".join(
                (
                    "".join((s[0] for s in x.split("_")))
                    for x in DiffusionModelStruct._simplify_keys(keys.split("+"), key_map=key_map)
                )
            )
            to_replace[part] = f"{prefix}.[{keys}]"
        # we then replace the parts
        for key, value in to_replace.items():
            path = path.replace(key, value)
        return path

    def simplify(self, key_map: dict[str, set[str]]) -> tp.Self:
        """Simplify the cache paths."""
        return self.apply(functools.partial(self.simplify_path, key_map=key_map))


@configclass
@dataclass
class DiffusionPtqCacheConfig:
    root: str
    dirpath: DiffusionQuantCacheConfig = field(init=False)
    path: DiffusionQuantCacheConfig = field(init=False)


================================================
FILE: deepcompressor/app/diffusion/config.py
================================================
# -*- coding: utf-8 -*-
"""Top-level config of post-training quantization for a diffusion model."""

import os
from dataclasses import dataclass, field

import diffusers.training_utils
import omniconfig
import torch
from omniconfig import ConfigParser, configclass

from deepcompressor.app.llm.config import LlmCacheConfig, LlmQuantConfig
from deepcompressor.data.utils import ScaleUtils
from deepcompressor.utils.config.output import OutputConfig

from .cache import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig
from .eval import DiffusionEvalConfig
from .nn.struct import DiffusionModelStruct
from .pipeline import DiffusionPipelineConfig
from .quant import DiffusionQuantConfig

__all__ = [
    "DiffusionPtqRunConfig",
    "DiffusionPtqCacheConfig",
    "DiffusionQuantCacheConfig",
    "DiffusionEvalConfig",
    "DiffusionPipelineConfig",
    "DiffusionQuantConfig",
]


@configclass
@dataclass
class DiffusionPtqRunConfig:
    """Top-level config of post-training quantization for a diffusion model.

    Args:
        cache (`DiffusionPtqCacheConfig`):
            The cache configuration.
        output (`OutputConfig`):
            The output directory configuration.
        pipeline (`DiffusionPipelineConfig`):
            The diffusion pipeline configuration.
        eval (`DiffusionEvalConfig`):
            The evaluation configuration.
        quant (`DiffusionQuantConfig`):
            The post-training quantization configuration.
        seed (`int`, *optional*, defaults to `12345`):
            The seed for reproducibility.
        skip_gen (`bool`, *optional*, defaults to `False`):
            Whether to skip generation.
        skip_eval (`bool`, *optional*, defaults to `False`):
            Whether to skip evaluation.
        load_from (`str`, *optional*, defaults to `""`):
            Directory path to load the model checkpoint from.
        save_model (`str`, *optional*, defaults to `""`):
            Directory path to save the model checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the quantization cache on save.
    """

    cache: DiffusionPtqCacheConfig | None
    output: OutputConfig
    pipeline: DiffusionPipelineConfig
    eval: DiffusionEvalConfig
    quant: DiffusionQuantConfig = field(metadata={omniconfig.ARGPARSE_KWARGS: {"prefix": ""}})
    text: LlmQuantConfig | None = None
    text_cache: LlmCacheConfig = field(default_factory=LlmCacheConfig)
    seed: int = 12345
    skip_gen: bool = False
    skip_eval: bool = False
    load_from: str = ""
    save_model: str = ""
    copy_on_save: bool = False

    def __post_init__(self):
        # region set text encoder quantization scale default dtype
        if self.text is not None and self.text.enabled_wgts:
            self.text.wgts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.wgts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        if self.text is not None and self.text.enabled_ipts:
            self.text.ipts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.ipts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        if self.text is not None and self.text.enabled_opts:
            self.text.opts.scale_dtypes = tuple(
                ScaleUtils.infer_scale_dtypes(self.text.opts.scale_dtypes, default_dtype=self.pipeline.dtype)
            )
        # endregion
        self.eval.num_gpus = min(torch.cuda.device_count(), self.eval.num_gpus)
        if self.eval.batch_size_per_gpu is None:
            self.eval.batch_size_per_gpu = max(1, self.eval.batch_size // self.eval.num_gpus)
        self.eval.batch_size = self.eval.batch_size_per_gpu * self.eval.num_gpus
        # region setup calib dataset path
        self.quant.calib.path = self.quant.calib.path.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
            data=self.quant.calib.data,
        )
        if self.quant.calib.path:
            self.quant.calib.path = os.path.abspath(os.path.expanduser(self.quant.calib.path))
        # endregion
        # region setup eval reference root
        self.eval.ref_root = self.eval.ref_root.format(
            dtype=self.pipeline.dtype,
            family=self.pipeline.family,
            model=self.pipeline.name,
            protocol=self.eval.protocol,
        )
        if self.eval.ref_root:
            self.eval.ref_root = os.path.abspath(os.path.expanduser(self.eval.ref_root))
        # endregion
        # region setup cache directory
        if self.cache is not None:
            if self.quant.enabled_wgts or self.quant.enabled_ipts or self.quant.enabled_opts:
                self.cache.dirpath = self.quant.generate_cache_dirpath(
                    root=self.cache.root, shift=self.pipeline.shift_activations, default_dtype=self.pipeline.dtype
                )
                self.cache.path = self.cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
            else:
                self.cache.dirpath = self.cache.path = None
        if self.text is not None and self.text.is_enabled():
            if not self.text_cache.root:
                self.text_cache.root = os.path.join(self.cache.root, "diffusion")
            self.text_cache.dirpath = self.text.generate_cache_dirpath(root=self.text_cache.root, seed=self.seed)
            self.text_cache.path = self.text_cache.dirpath.clone().add_children(f"{self.pipeline.name}.pt")
        # endregion
        # region setup output directory
        if self.output.dirname == "reference":
            assert self.eval.ref_root
            self.output.job = f"run-{self.eval.num_samples}"
            self.output.dirpath = self.eval.ref_root
            self.eval.ref_root = ""
            self.eval.gen_root = "{output}"
        else:
            if self.output.dirname == "default":
                self.output.dirname = self.generate_default_dirname()
            calib_dirname = self.quant.generate_calib_dirname() or "-"
            self.output.dirpath = os.path.join(
                self.output.root,
                "diffusion",
                self.pipeline.family,
                self.pipeline.name,
                *self.quant.generate_dirnames(default_dtype=self.pipeline.dtype)[:-1],
                calib_dirname,
                self.output.dirname,
            )
        if (self.eval.chunk_start > 0 or self.eval.chunk_step > 1) and not self.eval.chunk_only:
            self.output.job += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        # endregion
        diffusers.training_utils.set_seed(self.seed)

    def generate_default_dirname(self) -> str:
        name = "-shift" if self.pipeline.shift_activations else ""
        if self.quant.is_enabled():
            name += f"-{self.quant.generate_default_dirname()}"
        if self.text is not None and self.text.is_enabled():
            name += f"-text-{self.text.generate_default_dirname()}"
        size_name = ""
        if self.eval.height:
            size_name += f".h{self.eval.height}"
        if self.eval.width:
            size_name += f".w{self.eval.width}"
        if size_name:
            name += f"-{size_name[1:]}"
        sampling_name = ""
        if self.eval.num_steps is not None:
            sampling_name += f".t{self.eval.num_steps}"
        if self.eval.guidance_scale is not None:
            sampling_name += f".g{self.eval.guidance_scale}"
        if sampling_name:
            name += f"-{sampling_name[1:]}"
        if self.eval.num_samples != -1:
            name += f"-s{self.eval.num_samples}"
            if self.eval.chunk_only:
                name += f".c{self.eval.chunk_start}.{self.eval.chunk_step}"
        assert name[0] == "-"
        return name[1:]

    @classmethod
    def get_parser(cls) -> ConfigParser:
        """Get a parser for post-training quantization of a diffusion model.

        Returns:
            `ConfigParser`:
                A parser for post-training quantization of a diffusion model.
        """
        parser = ConfigParser("Diffusion Run configuration")
        DiffusionQuantConfig.set_key_map(DiffusionModelStruct._get_default_key_map())
        parser.add_config(cls)
        return parser


================================================
FILE: deepcompressor/app/diffusion/dataset/__init__.py
================================================
# -*- coding: utf-8 -*-

from .base import DiffusionDataset
from .calib import DiffusionCalibCacheLoader, DiffusionCalibCacheLoaderConfig


================================================
FILE: deepcompressor/app/diffusion/dataset/base.py
================================================
# -*- coding: utf-8 -*-
"""Dataset for diffusion models."""

import os
import random
import typing as tp

import numpy as np
import torch
import torch.utils.data
from torch.nn import functional as F

from deepcompressor.utils.common import tree_collate, tree_map

__all__ = ["DiffusionDataset"]


class DiffusionDataset(torch.utils.data.Dataset):
    path: str
    filenames: list[str]
    filepaths: list[str]

    def __init__(self, path: str, num_samples: int = -1, seed: int = 0, ext: str = ".npy") -> None:
        if os.path.exists(path):
            self.path = path
            if "caches" in os.listdir(path):
                path = os.path.join(path, "caches")
            filenames = [f for f in sorted(os.listdir(path)) if f.endswith(ext)]
            if num_samples > 0 and num_samples < len(filenames):
                random.Random(seed).shuffle(filenames)
                filenames = filenames[:num_samples]
                filenames = sorted(filenames)
            self.filenames = filenames
            self.filepaths = [os.path.join(path, f) for f in filenames]
        else:
            raise ValueError(f"Invalid data path: {path}")

    def __len__(self) -> int:
        return len(self.filepaths)

    def __getitem__(self, idx) -> dict[str, tp.Any]:
        data = np.load(self.filepaths[idx], allow_pickle=True).item()
        if isinstance(data["input_args"][0], str):
            name = data["input_args"][0]
            latent = np.load(os.path.join(self.path, "latents", name))
            data["input_args"][0] = latent
        if isinstance(data["input_kwargs"]["encoder_hidden_states"], str):
            name = data["input_kwargs"]["encoder_hidden_states"]
            text_emb = np.load(os.path.join(self.path, "text_embs", name))
            data["input_kwargs"]["encoder_hidden_states"] = text_emb
        data = tree_map(lambda x: torch.from_numpy(x), data)

        # Pad encoder_hidden_states to the encoder_attention_mask length (300 for PixArt)
        if "encoder_attention_mask" in data["input_kwargs"]:
            encoder_attention_mask = data["input_kwargs"]["encoder_attention_mask"]
            encoder_hidden_states = data["input_kwargs"]["encoder_hidden_states"]
            encoder_hidden_states = F.pad(
                encoder_hidden_states,
                (0, 0, 0, encoder_attention_mask.shape[1] - encoder_hidden_states.shape[1]),
            )
            data["input_kwargs"]["encoder_hidden_states"] = encoder_hidden_states

        return data

    def build_loader(self, **kwargs) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(self, collate_fn=tree_collate, **kwargs)


================================================
FILE: deepcompressor/app/diffusion/dataset/calib.py
================================================
# -*- coding: utf-8 -*-
"""Calibration dataset for diffusion models."""

import random
import typing as tp
from collections import OrderedDict
from dataclasses import MISSING, dataclass

import torch
import torch.nn as nn
import torch.utils.data
from diffusers.models.attention import JointTransformerBlock
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
)
from omniconfig import configclass

from deepcompressor.data.cache import (
    IOTensorsCache,
    ModuleForwardInput,
    TensorCache,
    TensorsCache,
)
from deepcompressor.data.utils.reshape import AttentionInputReshapeFn, LinearReshapeFn
from deepcompressor.dataset.action import CacheAction, ConcatCacheAction
from deepcompressor.dataset.cache import BaseCalibCacheLoader
from deepcompressor.dataset.config import BaseDataLoaderConfig

from ..nn.struct import DiffusionBlockStruct, DiffusionModelStruct
from .base import DiffusionDataset

__all__ = [
    "DiffusionCalibCacheLoaderConfig",
    "DiffusionCalibDataset",
    "DiffusionConcatCacheAction",
    "DiffusionCalibCacheLoader",
]


@configclass
@dataclass(kw_only=True)
class DiffusionCalibCacheLoaderConfig(BaseDataLoaderConfig):
    """Configuration for collecting calibration dataset for quantization.

    Args:
        data (`str`):
            Dataset name.
        num_samples (`int`):
            Number of dataset samples.
        batch_size (`int`):
            Batch size when loading dataset.
        path (`str`):
            Path to the dataset directory.
        num_workers (`int`):
            Number of workers for data loading.
    """

    path: str
    num_workers: int = 8

    def build_dataset(self) -> "DiffusionCalibDataset":
        """Build the calibration dataset."""
        return DiffusionCalibDataset(self.path, num_samples=self.num_samples)

    def build_loader(self) -> "DiffusionCalibCacheLoader":
        """Build the data loader."""
        return DiffusionCalibCacheLoader(self)


class DiffusionCalibDataset(DiffusionDataset):
    data: list[dict[str, tp.Any]]

    def __init__(self, path: str, num_samples: int = -1, seed: int = 0) -> None:
        super().__init__(path, num_samples=num_samples, seed=seed, ext=".pt")
        data = [torch.load(path) for path in self.filepaths]
        random.Random(seed).shuffle(data)
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> dict[str, tp.Any]:
        return self.data[idx]


class DiffusionConcatCacheAction(ConcatCacheAction):
    def info(
        self,
        name: str,
        module: nn.Module,
        tensors: dict[int | str, torch.Tensor],
        cache: TensorsCache,
    ) -> None:
        """Update cache information.

        Args:
            name (`str`):
                Module name.
            module (`nn.Module`):
                Module.
            tensors (`dict[int | str, torch.Tensor]`):
                Tensors to cache.
            cache (`TensorsCache`):
                Cache.
        """
        if isinstance(module, Attention):
            encoder_hidden_states = tensors.get("encoder_hidden_states", None)
            if encoder_hidden_states is None:
                tensors.pop("encoder_hidden_states", None)
                cache.tensors.pop("encoder_hidden_states", None)
            else:
                encoder_hidden_states_cache = cache.tensors["encoder_hidden_states"]
                encoder_channels_dim = 1 if encoder_hidden_states.dim() == 4 else -1
                if encoder_hidden_states_cache.channels_dim is None:
                    encoder_hidden_states_cache.channels_dim = encoder_channels_dim
                    if encoder_channels_dim == -1:
                        encoder_hidden_states_cache.reshape = LinearReshapeFn()
                    else:
                        encoder_hidden_states_cache.reshape = AttentionInputReshapeFn(encoder_channels_dim)
                else:
                    assert encoder_hidden_states_cache.channels_dim == encoder_channels_dim
            hidden_states, hidden_states_cache = tensors["hidden_states"], cache.tensors["hidden_states"]
            channels_dim = 1 if hidden_states.dim() == 4 else -1
            if hidden_states_cache.channels_dim is None:
                hidden_states_cache.channels_dim = channels_dim
                if channels_dim == -1:
                    hidden_states_cache.reshape = LinearReshapeFn()
                else:
                    hidden_states_cache.reshape = AttentionInputReshapeFn(channels_dim)
            else:
                assert hidden_states_cache.channels_dim == channels_dim
        return super().info(name, module, tensors, cache)


class DiffusionCalibCacheLoader(BaseCalibCacheLoader):
    config: DiffusionCalibCacheLoaderConfig
    dataset: DiffusionCalibDataset

    def __init__(self, config: DiffusionCalibCacheLoaderConfig) -> None:
        """Initialize the cache for the diffusion calibration dataset.

        Args:
            config (`DiffusionCalibCacheLoaderConfig`):
                Configuration for the calibration cache loader.
        """
        super().__init__(dataset=config.build_dataset(), batch_size=config.batch_size)
        self.batch_size = min(config.batch_size, len(self.dataset))
        self.config = config

    def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
        """Initialize cache.

        Args:
            name (`str`):
                Module name.
            module (`nn.Module`):
                Module.

        Returns:
            `IOTensorsCache`:
                Cache for inputs and outputs.
        """
        if isinstance(module, FluxSingleTransformerBlock):
            return IOTensorsCache(
                inputs=TensorsCache(
                    OrderedDict(
                        hidden_states=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
                        temb=TensorCache(channels_dim=1, reshape=LinearReshapeFn()),
                    )
                ),
                outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
            )
        elif isinstance(module, Attention):
            return IOTensorsCache(
                inputs=TensorsCache(
                    OrderedDict(
                        hidden_states=TensorCache(channels_dim=None, reshape=None),
                        encoder_hidden_states=TensorCache(channels_dim=None, reshape=None),
                    ),
                ),
                outputs=TensorCache(channels_dim=None, reshape=None),
            )
        else:
            return super()._init_cache(name, module)

    def iter_samples(self) -> tp.Generator[ModuleForwardInput, None, None]:
        """Iterate over the calibration dataset, yielding batched module forward inputs."""
        dataloader = self.dataset.build_loader(
            batch_size=self.batch_size, shuffle=False, drop_last=True, num_workers=self.config.num_workers
        )
        for data in dataloader:
            yield ModuleForwardInput(args=data["input_args"], kwargs=data["input_kwargs"])

    def _convert_layer_inputs(
        self, m: nn.Module, args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], save_all: bool = False
    ) -> ModuleForwardInput:
        """Convert layer inputs to module forward input.

        Args:
            m (`nn.Module`):
                Layer.
            args (`tuple[Any, ...]`):
                Layer input arguments.
            kwargs (`dict[str, Any]`):
                Layer input keyword arguments.
            save_all (`bool`, *optional*, defaults to `False`):
                Whether to save all inputs.

        Returns:
            `ModuleForwardInput`:
                Module forward input.
        """
        kwargs = {k: v for k, v in kwargs.items()}  # noqa: C416
        if "res_hidden_states_tuple" in kwargs:
            kwargs["res_hidden_states_tuple"] = None
        if "hidden_states" in kwargs:
            hidden_states = kwargs.pop("hidden_states")
            assert len(args) == 0, f"Invalid args: {args}"
        else:
            hidden_states = args[0]
        if isinstance(m, (FluxTransformerBlock, JointTransformerBlock)):
            if "encoder_hidden_states" in kwargs:
                encoder_hidden_states = kwargs.pop("encoder_hidden_states")
            else:
                encoder_hidden_states = args[1]
            return ModuleForwardInput(
                args=[
                    hidden_states.detach().cpu() if save_all else MISSING,
                    encoder_hidden_states.detach().cpu() if save_all else MISSING,
                ],
                kwargs=kwargs,
            )
        else:
            return ModuleForwardInput(
                args=[hidden_states.detach().cpu() if save_all else MISSING, *args[1:]], kwargs=kwargs
            )

    def _convert_layer_outputs(self, m: nn.Module, outputs: tp.Any) -> dict[str | int, tp.Any]:
        """Convert layer outputs to dictionary for updating the next layer inputs.

        Args:
            m (`nn.Module`):
                Layer.
            outputs (`Any`):
                Layer outputs.

        Returns:
            `dict[str | int, Any]`:
                Dictionary for updating the next layer inputs.
        """
        if isinstance(m, (FluxTransformerBlock, JointTransformerBlock)):
            assert isinstance(outputs, tuple) and len(outputs) == 2
            encoder_hidden_states, hidden_states = outputs
            return {0: hidden_states.detach().cpu(), 1: encoder_hidden_states.detach().cpu()}
        else:
            return super()._convert_layer_outputs(m, outputs)

    def iter_layer_activations(  # noqa: C901
        self,
        model: nn.Module | DiffusionModelStruct,
        *args,
        needs_inputs_fn: tp.Callable[[str, nn.Module], bool],
        needs_outputs_fn: tp.Callable[[str, nn.Module], bool] | None = None,
        action: CacheAction | None = None,
        skip_pre_modules: bool = True,
        skip_post_modules: bool = True,
        **kwargs,
    ) -> tp.Generator[
        tuple[
            str,
            tuple[
                DiffusionBlockStruct | nn.Module,
                dict[str, IOTensorsCache],
                dict[str, tp.Any],
            ],
        ],
        None,
        None,
    ]:
        """Iterate over model activations in layers.

        Args:
            model (`nn.Module` or `DiffusionModelStruct`):
                Model.
            needs_inputs_fn (`Callable[[str, nn.Module], bool]`):
                Function for determining whether to cache inputs for a module given its name and itself.
            needs_outputs_fn (`Callable[[str, nn.Module], bool]` or `None`, *optional*, defaults to `None`):
                Function for determining whether to cache outputs for a module given its name and itself.
            action (`CacheAction` or `None`, *optional*, defaults to `None`):
                Action for caching activations. If `None`, defaults to `DiffusionConcatCacheAction("cpu")`.
            skip_pre_modules (`bool`, *optional*, defaults to `True`):
                Whether to skip modules before the transformer blocks.
            skip_post_modules (`bool`, *optional*, defaults to `True`):
                Whether to skip modules after the transformer blocks.
            *args: Positional arguments for ``iter_samples``.
            **kwargs: Keyword arguments for ``iter_samples``.

        Yields:
            Generator[
                tuple[str, tuple[DiffusionBlockStruct | nn.Module, dict[str, IOTensorsCache], dict[str, tp.Any]]],
                None,
                None
            ]:
                Generator of tuple of
                    - layer name
                    - a tuple of
                        - layer itself
                        - inputs and outputs cache of each module in the layer
                        - layer input arguments
        """
        if not isinstance(model, DiffusionModelStruct):
            model_struct = DiffusionModelStruct.construct(model)
        else:
            model_struct = model
            model = model_struct.module
        assert isinstance(model_struct, DiffusionModelStruct)
        assert isinstance(model, nn.Module)
        action = DiffusionConcatCacheAction("cpu") if action is None else action
        layers, layer_structs, recomputes, use_prev_layer_outputs = model_struct.get_iter_layer_activations_args(
            skip_pre_modules=skip_pre_modules,
            skip_post_modules=skip_post_modules,
            **self.dataset[0]["input_kwargs"],
        )
        for layer_idx, (layer_name, (layer, layer_cache, layer_inputs)) in enumerate(
            self._iter_layer_activations(
                model,
                *args,
                action=action,
                layers=layers,
                needs_inputs_fn=needs_inputs_fn,
                needs_outputs_fn=needs_outputs_fn,
                recomputes=recomputes,
                use_prev_layer_outputs=use_prev_layer_outputs,
                **kwargs,
            )
        ):
            layer_kwargs = {k: v for k, v in layer_inputs[0].kwargs.items()}  # noqa: C416
            layer_kwargs.pop("hidden_states", None)
            layer_kwargs.pop("encoder_hidden_states", None)
            layer_kwargs.pop("temb", None)
            layer_struct = layer_structs[layer_idx]
            if isinstance(layer_struct, DiffusionBlockStruct):
                assert layer_struct.name == layer_name
                assert layer is layer_struct.module
                for transformer_block_struct in layer_struct.iter_transformer_block_structs():
                    for attn_struct in transformer_block_struct.iter_attention_structs():
                        if attn_struct.q_proj_name in layer_cache:
                            if not attn_struct.is_cross_attn():
                                cache = layer_cache[attn_struct.q_proj_name]
                                layer_cache[attn_struct.k_proj_name] = cache
                                layer_cache[attn_struct.v_proj_name] = cache
                        if attn_struct.add_k_proj_name in layer_cache:
                            assert not attn_struct.is_self_attn()
                            cache = layer_cache[attn_struct.add_k_proj_name]
                            layer_cache[attn_struct.add_v_proj_name] = cache
                            if attn_struct.is_joint_attn():
                                layer_cache[attn_struct.add_q_proj_name] = cache
                    ffn_struct = transformer_block_struct.ffn_struct
                    if ffn_struct is not None and ffn_struct.config.num_experts > 1:
                        num_experts = ffn_struct.config.num_experts
                        for expert_idx in range(num_experts):
                            if ffn_struct.up_proj_names[expert_idx] in layer_cache:
                                cache = layer_cache[ffn_struct.up_proj_names[expert_idx]]
                                for up_proj_name in ffn_struct.up_proj_names[expert_idx::num_experts]:
                                    layer_cache[up_proj_name] = cache
                            if ffn_struct.down_proj_names[expert_idx] in layer_cache:
                                cache = layer_cache[ffn_struct.down_proj_names[expert_idx]]
                                for down_proj_name in ffn_struct.down_proj_names[expert_idx::num_experts]:
                                    layer_cache[down_proj_name] = cache
            yield layer_name, (layer_struct, layer_cache, layer_kwargs)
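# Illustrative sketch (hypothetical names, not part of the repository): the loop
# above aliases one cache object across projections that receive the same input,
# e.g. q/k/v in self-attention, and interleaved expert projections collected via
# the stride trick `names[expert_idx::num_experts]`.

```python
num_experts = 2
# Interleaved layout assumed here: copies of expert i sit at positions i, i + E, ...
up_proj_names = ["ffn.0.up_a", "ffn.1.up_a", "ffn.0.up_b", "ffn.1.up_b"]
layer_cache = {"ffn.0.up_a": object(), "ffn.1.up_a": object()}

for expert_idx in range(num_experts):
    if up_proj_names[expert_idx] in layer_cache:
        cache = layer_cache[up_proj_names[expert_idx]]
        for name in up_proj_names[expert_idx::num_experts]:
            layer_cache[name] = cache  # every copy of this expert shares one cache

assert layer_cache["ffn.0.up_a"] is layer_cache["ffn.0.up_b"]
assert layer_cache["ffn.1.up_a"] is layer_cache["ffn.1.up_b"]
```

Sharing a single cache object means the calibration statistics gathered for one tied projection are automatically visible to all of its aliases.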


================================================
FILE: deepcompressor/app/diffusion/dataset/collect/calib.py
================================================
# -*- coding: utf-8 -*-
"""Collect calibration dataset."""

import os
from dataclasses import dataclass

import datasets
import torch
from omniconfig import configclass
from torch import nn
from tqdm import tqdm

from deepcompressor.app.diffusion.config import DiffusionPtqRunConfig
from deepcompressor.utils.common import hash_str_to_int, tree_map

from ...utils import get_control
from ..data import get_dataset
from .utils import CollectHook


def process(x: torch.Tensor) -> torch.Tensor:
    """Materialize a plain, detached copy of ``x`` for serialization.

    The tensor is cast to float32 before the NumPy round-trip (NumPy cannot
    represent bfloat16) and restored to its original dtype afterwards.
    """
    dtype = x.dtype
    return torch.from_numpy(x.float().numpy()).to(dtype)


def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []

    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)

    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()

        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    # `collect_config` is the module-level CollectConfig parsed in the `__main__` block below.
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls

        result_images = pipeline(prompts, generator=generators, **pipeline_kwargs).images
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(lambda x: process(x), c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()
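# Worked example (hypothetical helper, not part of the repository): the hook
# fires once per denoising step per guidance branch, each time appending
# `batch_size` per-sample caches, so the cache for sample j at step s under
# guidance branch g lives at the flat index s*B*G + g*B + j used above.

```python
def demo_cache_index(s: int, g: int, j: int, batch_size: int, num_guidances: int) -> int:
    # step-major, then guidance branch, then position within the batch
    return s * batch_size * num_guidances + g * batch_size + j

# With 2 steps, 2 guidance branches and batch size 2, walking (s, g, j) in
# nested order enumerates the flat cache list exactly once, in order.
order = [demo_cache_index(s, g, j, 2, 2) for s in range(2) for g in range(2) for j in range(2)]
assert order == list(range(8))
```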


@configclass
@dataclass
class CollectConfig:
    """Configuration for collecting calibration dataset.

    Args:
        root (`str`, *optional*, defaults to `"datasets"`):
            Root directory to save the collected dataset.
        dataset_name (`str`, *optional*, defaults to `"qdiff"`):
            Name of the collected dataset.
        data_path (`str`, *optional*, defaults to `"prompts/qdiff.yaml"`):
            Path to the prompt file.
        num_samples (`int`, *optional*, defaults to `128`):
            Number of samples to collect.
    """

    root: str = "datasets"
    dataset_name: str = "qdiff"
    data_path: str = "prompts/qdiff.yaml"
    num_samples: int = 128


if __name__ == "__main__":
    parser = DiffusionPtqRunConfig.get_parser()
    parser.add_config(CollectConfig, scope="collect", prefix="collect")
    configs, _, unused_cfgs, unused_args, unknown_args = parser.parse_known_args()
    ptq_config, collect_config = configs[""], configs["collect"]
    assert isinstance(ptq_config, DiffusionPtqRunConfig)
    assert isinstance(collect_config, CollectConfig)
    if len(unused_cfgs) > 0:
        print(f"Warning: unused configurations {unused_cfgs}")
    if unused_args is not None:
        print(f"Warning: unused arguments {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"

    collect_dirpath = os.path.join(
        collect_config.root,
        str(ptq_config.pipeline.dtype),
        ptq_config.pipeline.name,
        ptq_config.eval.protocol,
        collect_config.dataset_name,
        f"s{collect_config.num_samples}",
    )
    print(f"Saving caches to {collect_dirpath}")

    dataset = get_dataset(
        collect_config.data_path,
        max_dataset_size=collect_config.num_samples,
        return_gt=ptq_config.pipeline.task in ["canny-to-image"],
        repeat=1,
    )

    ptq_config.output.root = collect_dirpath
    os.makedirs(ptq_config.output.root, exist_ok=True)
    collect(ptq_config, dataset=dataset)


================================================
FILE: deepcompressor/app/diffusion/dataset/collect/utils.py
================================================
# -*- coding: utf-8 -*-
"""Common utilities for collecting data."""

import inspect
import typing as tp

import torch
import torch.nn as nn
from diffusers.models.transformers import (
    FluxTransformer2DModel,
    PixArtTransformer2DModel,
    SanaTransformer2DModel,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from deepcompressor.utils.common import tree_map, tree_split

__all__ = ["CollectHook"]


class CollectHook:
    def __init__(self, caches: list[dict[str, tp.Any]] | None = None, zero_redundancy: bool = False) -> None:
        self.caches = [] if caches is None else caches
        self.zero_redundancy = zero_redundancy

    def __call__(
        self,
        module: nn.Module,
        input_args: tuple[torch.Tensor, ...],
        input_kwargs: dict[str, tp.Any],
        output: tuple[torch.Tensor, ...],
    ) -> tp.Any:
        new_args = []
        signature = inspect.signature(module.forward)
        bound_arguments = signature.bind(*input_args, **input_kwargs)
        arguments = bound_arguments.arguments
        args_to_kwargs = {k: v for k, v in arguments.items() if k not in input_kwargs}
        input_kwargs.update(args_to_kwargs)

        if isinstance(module, UNet2DConditionModel):
            sample = input_kwargs.pop("sample")
            new_args.append(sample)
            timestep = input_kwargs["timestep"]
            timesteps = timestep
            if not torch.is_tensor(timesteps):
                is_mps = sample.device.type == "mps"
                if isinstance(timestep, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(sample.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(sample.shape[0])
            input_kwargs["timestep"] = timesteps
        elif isinstance(module, (PixArtTransformer2DModel, SanaTransformer2DModel)):
            new_args.append(input_kwargs.pop("hidden_states"))
        elif isinstance(module, FluxTransformer2DModel):
            new_args.append(input_kwargs.pop("hidden_states"))
        else:
            raise ValueError(f"Unknown model: {module}")
        cache = tree_map(lambda x: x.cpu(), {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": output})
        split_cache = tree_split(cache)

        if isinstance(module, PixArtTransformer2DModel) and self.zero_redundancy:
            for cache in split_cache:
                cache_kwargs = cache["input_kwargs"]
                encoder_hidden_states = cache_kwargs.pop("encoder_hidden_states")
                assert encoder_hidden_states.shape[0] == 1
                encoder_attention_mask = cache_kwargs.get("encoder_attention_mask", None)
                if encoder_attention_mask is not None:
                    encoder_hidden_states = encoder_hidden_states[:, : max(encoder_attention_mask.sum(), 1)]
                cache_kwargs["encoder_hidden_states"] = encoder_hidden_states

        self.caches.extend(split_cache)
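# Minimal sketch of the tree semantics the hook relies on (hypothetical helpers;
# the real `tree_map`/`tree_split` live in `deepcompressor.utils.common` and also
# handle tensors and tuples): map a function over every leaf of a nested dict,
# and split a batched tree into one tree per sample.

```python
def demo_tree_map(fn, tree):
    """Apply `fn` to every leaf of a nested dict."""
    if isinstance(tree, dict):
        return {k: demo_tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

def demo_tree_split(tree, batch_size):
    """Split a batched tree into one tree per sample (leaves are indexable)."""
    return [demo_tree_map(lambda leaf, i=i: leaf[i], tree) for i in range(batch_size)]

batched = {"input_kwargs": {"timestep": [999, 999], "guidance": [3.5, 3.5]}, "step": [0, 1]}
per_sample = demo_tree_split(batched, 2)
```

Note the `i=i` default argument, which freezes the loop index in each lambda; without it every split would index the last sample.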


================================================
FILE: deepcompressor/app/diffusion/dataset/data/COCO/COCO.py
================================================
# coding=utf-8
# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""COCO"""

import json
import os
import random
from pathlib import Path

import datasets
from PIL import Image

_CITATION = """
@article{DBLP:journals/corr/LinMBHPRDZ14,
  author    = {Tsung{-}Yi Lin and
               Michael Maire and
               Serge J. Belongie and
               Lubomir D. Bourdev and
               Ross B. Girshick and
               James Hays and
               Pietro Perona and
               Deva Ramanan and
               Piotr Doll{\'{a}}r and
               C. Lawrence Zitnick},
  title     = {Microsoft {COCO:} Common Objects in Context},
  journal   = {CoRR},
  volume    = {abs/1405.0312},
  year      = {2014},
  url       = {http://arxiv.org/abs/1405.0312},
  eprinttype = {arXiv},
  eprint    = {1405.0312},
  timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/LinMBHPRDZ14.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

_DESCRIPTION = """
MS COCO is a large-scale object detection, segmentation, and captioning dataset.
 COCO has several features: Object segmentation, Recognition in context, Superpixel stuff segmentation,
 330K images (>200K labeled), 1.5 million object instances, 80 object categories, 91 stuff categories,
 5 captions per image, 250,000 people with keypoints.
"""

_HOMEPAGE = "https://cocodataset.org/#home"

_LICENSE = "CC BY 4.0"


_IMAGES_URLS = {
    "train": "http://images.cocodataset.org/zips/train2014.zip",
    "validation": "http://images.cocodataset.org/zips/val2014.zip",
}

_KARPATHY_FILES_URL = "https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip"

_FEATURES = datasets.Features(
    {
        "filepath": datasets.Value("string"),
        "filename": datasets.Value("string"),
        "image": datasets.Image(),
        "image_path": datasets.Value("string"),
        "image_root": datasets.Value("string"),
        "prompt": datasets.Value("string"),
        "prompt_id": datasets.Value("int32"),
        "imgid": datasets.Value("int32"),
        "split": datasets.Value("string"),
        "cocoid": datasets.Value("int32"),
        "sentences_raw": [datasets.Value("string")],
        "sentids": [datasets.Value("int32")],
        "sentences_sentid": [datasets.Value("int32")],
        "sentences_tokens": [[datasets.Value("string")]],
    }
)


def hash_string_to_int(s: str) -> int:
    modulus = 10**9 + 7  # Large prime modulus
    hash_int = 0
    for char in s:
        hash_int = (hash_int * 31 + ord(char)) % modulus
    return hash_int


class COCOConfig(datasets.BuilderConfig):
    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        super(COCOConfig, self).__init__(
            name=kwargs.get("name", "default"),
            version=kwargs.get("version", "0.0.0"),
            data_dir=kwargs.get("data_dir", None),
            data_files=kwargs.get("data_files", None),
            description=kwargs.get("description", None),
        )
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt


class COCO(datasets.GeneratorBasedBuilder):
    """COCO"""

    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = COCOConfig
    BUILDER_CONFIGS = [
        COCOConfig(name="COCO_val", version=VERSION, description="COCO validation prompt set"),
        COCOConfig(name="COCO_train", version=VERSION, description="COCO train prompt set"),
        COCOConfig(name="COCO_full", version=VERSION, description="COCO full prompt set"),
    ]
    DEFAULT_CONFIG_NAME = "COCO_val"

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        annotation_file = os.path.join(dl_manager.download_and_extract(_KARPATHY_FILES_URL), "dataset_coco.json")
        image_folders = {k: Path(v) for k, v in dl_manager.download_and_extract(_IMAGES_URLS).items()}

        if self.config.name == "COCO_full":
            split_keys = ["validation", "train"]
        else:
            split_keys = [self.config.name.split("_")[-1]]

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "annotation_file": annotation_file,
                    "image_folders": image_folders,
                    "split_keys": split_keys,
                },
            ),
        ]

    def _generate_examples(
        self, annotation_file: str, image_folders: dict[str, Path], split_keys: list[str] | tuple[str, ...]
    ):
        with open(annotation_file, "r", encoding="utf-8") as fi:
            annotations = json.load(fi)
        metas = []
        for split_key in split_keys:
            for image_metadata in annotations["images"]:
                if split_key == "train":
                    if image_metadata["split"] != "train" and image_metadata["split"] != "restval":
                        continue
                elif split_key == "val":
                    if image_metadata["split"] != "val":
                        continue
                elif split_key == "test":
                    if image_metadata["split"] != "test":
                        continue

                metas.append(image_metadata)

        if self.config.max_dataset_size > 0:
            random.Random(0).shuffle(metas)
            metas = metas[: self.config.max_dataset_size]
            metas = sorted(metas, key=lambda x: x["filename"])

        for i, meta in enumerate(metas):
            if "val2014" in meta["filename"]:
                image_root = os.path.join(image_folders["validation"], "val2014")
            else:
                image_root = os.path.join(image_folders["train"], "train2014")
            filename = meta["filename"].replace(".jpg", "").replace(".png", "")
            image_path = os.path.join(image_root, filename + ".jpg")

            sentences_raw = [caption["raw"] for caption in meta["sentences"]]
            prompt_id = hash_string_to_int(filename) % len(sentences_raw)
            prompt = sentences_raw[prompt_id]

            yield (
                i,
                {
                    "filename": filename,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "image_path": image_path,
                    "image_root": image_root,
                    "prompt": prompt,
                    "prompt_id": prompt_id,
                    "imgid": meta["imgid"],
                    "split": self.config.name,
                    "coco_id": meta["cocoid"],
                    "sentences_raw": sentences_raw,
                    "sentids": meta["sentids"],
                    "sentences_sentid": [caption["sentid"] for caption in meta["sentences"]],
                    "sentences_tokens": [caption["tokens"] for caption in meta["sentences"]],
                },
            )


================================================
FILE: deepcompressor/app/diffusion/dataset/data/COCO/__init__.py
================================================


================================================
FILE: deepcompressor/app/diffusion/dataset/data/DCI/DCI.py
================================================
import os
import random

import datasets
import yaml
from PIL import Image

_CITATION = """\
@InProceedings{Urbanek_2024_CVPR,
    author    = {Urbanek, Jack and Bordes, Florian and Astolfi, Pietro and Williamson, Mary and Sharma, Vasu and Romero-Soriano, Adriana},
    title     = {A Picture is Worth More Than 77 Text Tokens: Evaluating CLIP-Style Models on Dense Captions},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2024},
    pages     = {26700-26709}
}
"""  # noqa: E501

_DESCRIPTION = """\
The Densely Captioned Images dataset, or DCI, consists of 7805 images from SA-1B,
 each with a complete description aiming to capture the full visual detail of what is present in the image.
 Much of the description is directly aligned to submasks of the image.
"""

_HOMEPAGE = "https://github.com/facebookresearch/DCI"

_LICENSE = "Attribution-NonCommercial 4.0 International (https://github.com/facebookresearch/DCI/blob/main/LICENSE)"

IMAGE_URL = "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.gz"

PROMPT_URLS = {"sDCI": "https://huggingface.co/datasets/mit-han-lab/svdquant-datasets/resolve/main/sDCI.yaml"}


class DCIConfig(datasets.BuilderConfig):
    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        super(DCIConfig, self).__init__(
            name=kwargs.get("name", "default"),
            version=kwargs.get("version", "0.0.0"),
            data_dir=kwargs.get("data_dir", None),
            data_files=kwargs.get("data_files", None),
            description=kwargs.get("description", None),
        )
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt


class DCI(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = DCIConfig
    BUILDER_CONFIGS = [DCIConfig(name="sDCI", version=VERSION, description="sDCI full prompt set")]
    DEFAULT_CONFIG_NAME = "sDCI"

    def _info(self):
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                "meta_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        image_url = IMAGE_URL
        meta_url = PROMPT_URLS[self.config.name]

        meta_path = dl_manager.download(meta_url)
        image_root = dl_manager.download_and_extract(image_url)

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
            )
        ]

    def _generate_examples(self, meta_path: str, image_root: str):
        meta = yaml.safe_load(open(meta_path, "r"))
        names = list(meta.keys())
        if self.config.max_dataset_size > 0:
            random.Random(0).shuffle(names)
            names = names[: self.config.max_dataset_size]
            names = sorted(names)

        for i, name in enumerate(names):
            prompt = meta[name]
            image_path = os.path.join(image_root, f"{name}.jpg")
            yield (
                i,
                {
                    "filename": name,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    "meta_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                },
            )


================================================
FILE: deepcompressor/app/diffusion/dataset/data/DCI/__init__.py
================================================


================================================
FILE: deepcompressor/app/diffusion/dataset/data/MJHQ/MJHQ.py
================================================
import json
import os
import random

import datasets
from PIL import Image

_CITATION = """\
@misc{li2024playground,
      title={Playground v2.5: Three Insights towards Enhancing Aesthetic Quality in Text-to-Image Generation},
      author={Daiqing Li and Aleks Kamko and Ehsan Akhgari and Ali Sabet and Linmiao Xu and Suhail Doshi},
      year={2024},
      eprint={2402.17245},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
"""

_DESCRIPTION = """\
We introduce a new benchmark, MJHQ-30K, for automatic evaluation of a model’s aesthetic quality.
 The benchmark computes FID on a high-quality dataset to gauge aesthetic quality.
"""

_HOMEPAGE = "https://huggingface.co/datasets/playgroundai/MJHQ-30K"

_LICENSE = (
    "Playground v2.5 Community License "
    "(https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md)"
)

IMAGE_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/mjhq30k_imgs.zip"

META_URL = "https://huggingface.co/datasets/playgroundai/MJHQ-30K/resolve/main/meta_data.json"


class MJHQConfig(datasets.BuilderConfig):
    def __init__(self, max_dataset_size: int = -1, return_gt: bool = False, **kwargs):
        super(MJHQConfig, self).__init__(
            name=kwargs.get("name", "default"),
            version=kwargs.get("version", "0.0.0"),
            data_dir=kwargs.get("data_dir", None),
            data_files=kwargs.get("data_files", None),
            description=kwargs.get("description", None),
        )
        self.max_dataset_size = max_dataset_size
        self.return_gt = return_gt


class MJHQ(datasets.GeneratorBasedBuilder):
    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = MJHQConfig
    BUILDER_CONFIGS = [MJHQConfig(name="MJHQ", version=VERSION, description="MJHQ-30K full dataset")]
    DEFAULT_CONFIG_NAME = "MJHQ"

    def _info(self):
        features = datasets.Features(
            {
                "filename": datasets.Value("string"),
                "category": datasets.Value("string"),
                "image": datasets.Image(),
                "prompt": datasets.Value("string"),
                "prompt_path": datasets.Value("string"),
                "image_root": datasets.Value("string"),
                "image_path": datasets.Value("string"),
                "split": datasets.Value("string"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, features=features, homepage=_HOMEPAGE, license=_LICENSE, citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.download.DownloadManager):
        meta_path = dl_manager.download(META_URL)
        image_root = dl_manager.download_and_extract(IMAGE_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"meta_path": meta_path, "image_root": image_root}
            ),
        ]

    def _generate_examples(self, meta_path: str, image_root: str):
        with open(meta_path, "r") as f:
            meta = json.load(f)

        names = list(meta.keys())
        if self.config.max_dataset_size > 0:
            random.Random(0).shuffle(names)
            names = names[: self.config.max_dataset_size]
            names = sorted(names)

        for i, name in enumerate(names):
            category = meta[name]["category"]
            prompt = meta[name]["prompt"]
            image_path = os.path.join(image_root, category, f"{name}.jpg")
            yield (
                i,
                {
                    "filename": name,
                    "category": category,
                    "image": Image.open(image_path) if self.config.return_gt else None,
                    "prompt": prompt,
                    "meta_path": meta_path,
                    "image_root": image_root,
                    "image_path": image_path,
                    "split": self.config.name,
                },
            )
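
The subset selection above (fixed-seed shuffle, truncate, then re-sort) is the recipe shared by the dataset builders and by `load_dataset_yaml` below: the fixed `random.Random(0)` picks the same subset on every run and machine, and the final sort keeps iteration order stable. Extracted as a minimal, self-contained sketch:

```python
import random


def select_deterministic_subset(names: list[str], max_size: int) -> list[str]:
    # Same recipe as the builders: fixed-seed shuffle -> truncate -> sort,
    # so every run (and every worker) sees the identical subset in a
    # stable order. A non-positive max_size keeps the full list.
    names = list(names)
    if max_size > 0:
        random.Random(0).shuffle(names)
        names = names[:max_size]
        names = sorted(names)
    return names
```

Because the seed is a constant, caching keyed on the subset (e.g. FID statistics) stays valid across runs.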


================================================
FILE: deepcompressor/app/diffusion/dataset/data/MJHQ/__init__.py
================================================


================================================
FILE: deepcompressor/app/diffusion/dataset/data/__init__.py
================================================
import os
import random

import datasets
import yaml

__all__ = ["get_dataset"]


def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat: int = 4) -> dict:
    with open(meta_path, "r") as f:
        meta = yaml.safe_load(f)
    names = list(meta.keys())
    if max_dataset_size > 0:
        random.Random(0).shuffle(names)
        names = names[:max_dataset_size]
        names = sorted(names)

    ret = {"filename": [], "prompt": [], "meta_path": []}
    idx = 0
    for name in names:
        prompt = meta[name]
        for j in range(repeat):
            ret["filename"].append(f"{name}-{j}")
            ret["prompt"].append(prompt)
            ret["meta_path"].append(meta_path)
            idx += 1
    return ret


def get_dataset(
    name: str,
    config_name: str | None = None,
    split: str = "train",
    max_dataset_size: int = -1,
    return_gt: bool = False,
    repeat: int = 4,
    chunk_start: int = 0,
    chunk_step: int = 1,
) -> datasets.Dataset:
    prefix = os.path.dirname(__file__)
    kwargs = {
        "name": config_name,
        "split": split,
        "trust_remote_code": True,
        "token": True,
        "max_dataset_size": max_dataset_size,
    }
    if name.endswith((".yaml", ".yml")):
        dataset = datasets.Dataset.from_dict(
            load_dataset_yaml(name, max_dataset_size=max_dataset_size, repeat=repeat),
            features=datasets.Features(
                {
                    "filename": datasets.Value("string"),
                    "prompt": datasets.Value("string"),
                    "meta_path": datasets.Value("string"),
                }
            ),
        )
    else:
        path = os.path.join(prefix, name)
        if name in ("COCO", "DCI", "MJHQ"):
            dataset = datasets.load_dataset(path, return_gt=return_gt, **kwargs)
        else:
            raise ValueError(f"Unknown dataset name: {name}")
    assert not hasattr(dataset, "_unchunk_size")
    assert not hasattr(dataset, "_chunk_start")
    assert not hasattr(dataset, "_chunk_step")
    unchunk_size = len(dataset)
    if chunk_step > 1 or chunk_start > 0:
        assert 0 <= chunk_start < chunk_step
        dataset = dataset.select(range(chunk_start, len(dataset), chunk_step))
    else:
        chunk_start, chunk_step = 0, 1
    dataset._unchunk_size = unchunk_size
    dataset._chunk_start = chunk_start
    dataset._chunk_step = chunk_step
    return dataset
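
`get_dataset` distributes work across jobs by strided selection: chunk `(chunk_start, chunk_step)` keeps indices `chunk_start, chunk_start + chunk_step, ...`, so workers sharing a step but using different starts partition the dataset exactly. A dependency-free sketch of that index arithmetic:

```python
def chunk_indices(n: int, chunk_start: int = 0, chunk_step: int = 1) -> list[int]:
    # Mirrors get_dataset's dataset.select(range(chunk_start, n, chunk_step)):
    # workers with the same step but distinct starts get disjoint slices,
    # and the union over all starts covers every index exactly once.
    assert 0 <= chunk_start < chunk_step
    return list(range(chunk_start, n, chunk_step))
```

For example, four workers launched with `chunk_step=4` and `chunk_start` 0..3 jointly cover the whole dataset without overlap.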


================================================
FILE: deepcompressor/app/diffusion/dataset/data/dump.py
================================================
import argparse
import os

import yaml
from tqdm import tqdm

from ...utils import get_control
from . import get_dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmarks", type=str, nargs="*", default=["COCO", "DCI", "MJHQ"])
    parser.add_argument("--max-dataset-size", type=int, default=-1)
    parser.add_argument("--dump-root", type=str, default="benchmarks")
    parser.add_argument("--copy-images", action="store_true")
    parser.add_argument("--prompts-only", action="store_true")
    parser.add_argument("--controls", type=str, nargs="*", default=["canny-to-image", "depth-to-image", "inpainting"])
    parser.add_argument("--chunk-start", type=int, default=0)
    parser.add_argument("--chunk-step", type=int, default=1)
    args = parser.parse_args()

    if "depth-to-image" in args.controls:
        from image_gen_aux import DepthPreprocessor

        processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf").to("cuda")

    for benchmark in args.benchmarks:
        dataset = get_dataset(
            benchmark,
            max_dataset_size=args.max_dataset_size,
            return_gt=True,
            chunk_start=args.chunk_start,
            chunk_step=args.chunk_step,
        )
        prompts = {}
        benchmark_root = os.path.join(args.dump_root, benchmark, f"{dataset.config_name}-{dataset._unchunk_size}")
        for row in tqdm(dataset, desc=f"Dumping {dataset.config_name}"):
            prompts[row["filename"]] = row["prompt"]
            if not args.prompts_only:
                image = row.get("image", None)
                if image is not None:
                    image_root = os.path.join(benchmark_root, "images")
                    os.makedirs(image_root, exist_ok=True)
                    if args.copy_images:
                        image.save(os.path.join(image_root, row["filename"] + ".png"))
                    else:
                        ext = os.path.basename(row["image_path"]).split(".")[-1]
                        os.symlink(
                            os.path.abspath(os.path.expanduser(row["image_path"])),
                            os.path.abspath(os.path.expanduser(os.path.join(image_root, row["filename"] + f".{ext}"))),
                        )
                    if "canny-to-image" in args.controls:
                        canny_root = os.path.join(benchmark_root, "canny_images")
                        os.makedirs(canny_root, exist_ok=True)
                        canny = get_control("canny-to-image", image)
                        canny.save(os.path.join(canny_root, row["filename"] + ".png"))
                    if "depth-to-image" in args.controls:
                        depth_root = os.path.join(benchmark_root, "depth_images")
                        os.makedirs(depth_root, exist_ok=True)
                        depth = get_control("depth-to-image", image, processor=processor)
                        depth.save(os.path.join(depth_root, row["filename"] + ".png"))
                    if "inpainting" in args.controls:
                        mask_root = os.path.join(benchmark_root, "mask_images")
                        cropped_image_root = os.path.join(benchmark_root, "cropped_images")
                        os.makedirs(mask_root, exist_ok=True)
                        os.makedirs(cropped_image_root, exist_ok=True)
                        cropped_image, mask_image = get_control("inpainting", image, names=row["filename"])
                        cropped_image.save(os.path.join(cropped_image_root, row["filename"] + ".png"))
                        mask_image.save(os.path.join(mask_root, row["filename"] + ".png"))

        if args.chunk_step == 1:
            os.makedirs(benchmark_root, exist_ok=True)
            with open(os.path.join(benchmark_root, "prompts.yaml"), "w") as f:
                yaml.dump(prompts, f)


================================================
FILE: deepcompressor/app/diffusion/eval/__init__.py
================================================
# -*- coding: utf-8 -*-

from .config import DiffusionEvalConfig


================================================
FILE: deepcompressor/app/diffusion/eval/config.py
================================================
# -*- coding: utf-8 -*-
"""Diffusion model evaluation."""

import logging
import os
import typing as tp
from dataclasses import dataclass, field

import datasets
import diffusers
import omniconfig
import torch
from diffusers import DiffusionPipeline
from omniconfig import configclass
from torch import multiprocessing as mp
from tqdm import tqdm

from deepcompressor.app.diffusion.dataset.data import get_dataset
from deepcompressor.utils.common import hash_str_to_int

from ..utils import get_control
from .metrics import compute_image_metrics

__all__ = ["DiffusionEvalConfig"]


@configclass
@dataclass
class DiffusionEvalConfig:
    """Diffusion model evaluation configuration.

    Args:
        protocol (`str`):
            The protocol of the evaluation pipeline.
        num_gpus (`int`, *optional*, defaults to `1`):
            The number of GPUs to use.
        batch_size (`int`, *optional*, defaults to `1`):
            The batch size used for inference.
        batch_size_per_gpu (`int`, *optional*, defaults to `None`):
            The batch size per GPU.
        height (`int`, *optional*, defaults to `None`):
            The height of the generated images.
        width (`int`, *optional*, defaults to `None`):
            The width of the generated images.
        clean_caption (`bool`, *optional*, defaults to `None`):
            Whether to clean the caption.
        num_steps (`int`, *optional*, defaults to `None`):
            The number of inference steps.
        guidance_scale (`float`, *optional*, defaults to `None`):
            The guidance scale.
        num_samples (`int`, *optional*, defaults to `1024`):
            The number of samples to generate.
        benchmarks (`list[str]`, *optional*, defaults to `["COCO", "DCI", "MJHQ", "GenEval"]`):
            The benchmark datasets to evaluate on.
        gt_metrics (`list[str]`, *optional*, defaults to `["clip_iqa", "clip_score", "image_reward", "fid"]`):
            The metrics to compute against the ground-truth dataset images.
        ref_metrics (`list[str]`, *optional*, defaults to `["psnr", "lpips", "ssim"]`):
            The metrics to compute against reference generated images.
        gen_root (`str`, *optional*, defaults to `""`):
            The root directory path for the generated images.
        ref_root (`str`, *optional*, defaults to `""`):
            The root directory path to the reference images.
        gt_stats_root (`str`, *optional*, defaults to `""`):
            The root directory path to the ground truth statistics.
        control_root (`str` or `None`, *optional*, defaults to `None`):
            The root directory path to the control images.
        chunk_start (`int`, *optional*, defaults to `0`):
            The starting chunk index.
        chunk_step (`int`, *optional*, defaults to `1`):
            The chunk step size.
        chunk_only (`bool`, *optional*, defaults to `False`):
            Whether to only generate this chunk and skip metric computation.
    """

    protocol: str

    num_gpus: int = field(default=1, metadata={omniconfig.ARGPARSE_ARGS: ("--num-gpus", "-n")})
    batch_size: int = 1
    batch_size_per_gpu: int | None = None

    height: int | None = None
    width: int | None = None
    clean_caption: bool | None = None
    num_steps: int | None = None
    guidance_scale: float | None = None
    num_samples: int = 1024

    benchmarks: list[str] = field(
        default_factory=lambda: ["COCO", "DCI", "MJHQ", "GenEval"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    gt_metrics: list[str] = field(
        default_factory=lambda: ["clip_iqa", "clip_score", "image_reward", "fid"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    ref_metrics: list[str] = field(
        default_factory=lambda: ["psnr", "lpips", "ssim"],
        metadata={omniconfig.ARGPARSE_KWARGS: {"nargs": "+", "type": str}},
    )
    gen_root: str = ""
    ref_root: str = ""
    gt_stats_root: str = ""
    control_root: str | None = None

    chunk_start: int = 0
    chunk_step: int = 1
    chunk_only: bool = False

    def __post_init__(self):
        assert self.protocol
        self.protocol = self.protocol.lower().format(num_steps=self.num_steps, guidance_scale=self.guidance_scale)
        assert 0 <= self.chunk_start < self.chunk_step
        if self.chunk_start == 0 and self.chunk_step == 1:
            self.chunk_only = False

    def get_pipeline_kwargs(self) -> dict[str, tp.Any]:
        kwargs = {}
        if self.height is not None:
            kwargs["height"] = self.height
        if self.width is not None:
            kwargs["width"] = self.width
        if self.clean_caption is not None:
            kwargs["clean_caption"] = self.clean_caption
        if self.num_steps is not None:
            kwargs["num_inference_steps"] = self.num_steps
        if self.guidance_scale is not None:
            kwargs["guidance_scale"] = self.guidance_scale
        return kwargs

    def _generate(
        self,
        rank: int,
        dataset: datasets.Dataset,
        pipeline: DiffusionPipeline,
        dirpath: str,
        logger: logging.Logger,
        dataset_name: str | None = None,
        task: str = "text-to-image",
        control_root: str | None = None,
    ) -> None:
        if self.num_gpus > 1:
            pipeline = pipeline.to(rank)
        if rank == 0:
            logger.info(
                f"  {dataset.config_name} has {len(dataset)} samples "
                f"(chunk_start={dataset._chunk_start}, chunk_step={dataset._chunk_step},"
                f" unchunk_size={dataset._unchunk_size})"
            )
        pipeline.set_progress_bar_config(
            desc="Sampling",
            leave=False,
            dynamic_ncols=True,
            position=1,
            disable=self.num_gpus > 1,
        )
        if dataset_name is None:
            dataset_name = dataset.config_name
        for batch in tqdm(
            dataset.iter(batch_size=self.batch_size, drop_last_batch=False),
            desc=dataset_name if self.num_gpus == 1 else f"{dataset_name} (GPU {rank})",
            leave=False,
            dynamic_ncols=True,
            position=rank,
            total=(len(dataset) + self.batch_size - 1) // self.batch_size,
        ):
            filenames = batch["filename"][rank :: self.num_gpus]
            if len(filenames) == 0:
                continue
            if all(os.path.exists(os.path.join(dirpath, f"{filename}.png")) for filename in filenames):
                continue
            prompts = batch["prompt"][rank :: self.num_gpus]
            seeds = [hash_str_to_int(name) for name in filenames]
            diffusers.training_utils.set_seed(seeds[0])
            generators = [torch.Generator().manual_seed(seed) for seed in seeds]

            pipeline_kwargs = self.get_pipeline_kwargs()

            if task in ["canny-to-image", "depth-to-image", "inpainting"]:
                controls = get_control(
                    task,
                    batch["image"],
                    names=batch["filename"],
                    data_root=os.path.join(control_root, f"{dataset_name}-{dataset._unchunk_size}"),
                )
                if task == "inpainting":
                    pipeline_kwargs["image"] = controls[0]
                    pipeline_kwargs["mask_image"] = controls[1]
                else:
                    pipeline_kwargs["control_image"] = controls

            output = pipeline(prompts, generator=generators, **pipeline_kwargs)
            images = output.images
            for filename, image in zip(filenames, images, strict=True):
                image.save(os.path.join(dirpath, f"{filename}.png"))

    def generate(
        self,
        pipeline: DiffusionPipeline,
        gen_root: str = "",
        task: str = "text-to-image",
    ) -> None:
        logger = logging.getLogger(f"{__name__}.DiffusionEval")
        gen_root = gen_root or self.gen_root
        for benchmark in self.benchmarks:
            dataset = get_dataset(
                benchmark,
                max_dataset_size=self.num_samples,
                chunk_start=self.chunk_start,
                chunk_step=self.chunk_step,
                return_gt=task in ["canny-to-image"],
                repeat=1,
            )
            if benchmark.endswith(".yaml") or benchmark.endswith(".yml"):
                dataset_name = os.path.splitext(os.path.basename(benchmark))[0]
                dirpath = os.path.join(
                    gen_root,
                    "samples",
                    "YAML",
                    f"{dataset_name}-{dataset._unchunk_size}",
                )
            else:
                dataset_name = dataset.config_name
                dirpath = os.path.join(
                    gen_root,
                    "samples",
                    benchmark,
                    f"{dataset.config_name}-{dataset._unchunk_size}",
                )
            if self.chunk_only:
                dirpath += f".{dataset._chunk_start}.{dataset._chunk_step}"
            os.makedirs(dirpath, exist_ok=True)
            control_root = os.path.join(self.control_root, benchmark) if self.control_root else None
            args = (dataset, pipeline, dirpath, logger, dataset_name, task, control_root)
            if self.num_gpus == 1:
                self._generate(0, *args)
            else:
                mp.spawn(self._generate, args=args, nprocs=self.num_gpus, join=True)

    def evaluate(
        self, pipeline: DiffusionPipeline, gen_root: str = "", skip_gen: bool = False, task: str = "text-to-image"
    ) -> dict[str, tp.Any] | None:
        gen_root = gen_root or self.gen_root
        if not skip_gen:
            self.generate(pipeline, gen_root=gen_root, task=task)
        if not self.chunk_only:
            return compute_image_metrics(
                gen_root=gen_root,
                benchmarks=self.benchmarks,
                max_dataset_size=self.num_samples,
                chunk_start=self.chunk_start,
                chunk_step=self.chunk_step,
                ref_root=self.ref_root,
                gt_stats_root=self.gt_stats_root,
                gt_metrics=self.gt_metrics,
                ref_metrics=self.ref_metrics,
            )
        else:
            return {}
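
Per-sample seeding in `_generate` derives one seed per filename via `hash_str_to_int`, so a regenerated image is reproducible regardless of batch composition or GPU assignment. The helper lives in `deepcompressor.utils.common`; below is a hypothetical digest-based stand-in (the exact hash used there may differ), which avoids Python's built-in `hash()` because that is salted per process:

```python
import hashlib


def hash_str_to_int(s: str, bits: int = 32) -> int:
    # Hypothetical stand-in for deepcompressor.utils.common.hash_str_to_int:
    # a cryptographic digest maps the same filename to the same seed on
    # every process and machine, unlike the salted built-in hash().
    return int(hashlib.md5(s.encode("utf-8")).hexdigest(), 16) % (1 << bits)
```

Seeding a `torch.Generator` with this value ties each output image to its filename alone.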


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/__init__.py
================================================
import logging
import os

from deepcompressor.app.diffusion.dataset.data import get_dataset

from .fid import compute_fid
from .image_reward import compute_image_reward
from .multimodal import compute_image_multimodal_metrics
from .similarity import compute_image_similarity_metrics

logging.getLogger("PIL").setLevel(logging.WARNING)

__all__ = ["compute_image_metrics"]


def compute_image_metrics(
    gen_root: str,
    benchmarks: str | tuple[str, ...] = ("DCI", "GenAIBench", "GenEval", "MJHQ", "T2ICompBench"),
    max_dataset_size: int = -1,
    chunk_start: int = 0,
    chunk_step: int = 1,
    chunk_only: bool = False,
    ref_root: str = "",
    gt_stats_root: str = "",
    gt_metrics: tuple[str, ...] = ("clip_iqa", "clip_score", "image_reward", "fid"),
    ref_metrics: tuple[str, ...] = ("psnr", "lpips", "ssim", "fid"),
) -> dict:
    if chunk_start == 0 and chunk_step == 1:
        chunk_only = False
    assert chunk_start == 0 and chunk_step == 1, "Chunking is not supported for image data."
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    if isinstance(benchmarks, str):
        benchmarks = (benchmarks,)
    gt_multimodal_metrics, gt_similarity_metrics, gt_other_metrics = categorize_metrics(gt_metrics)
    _, ref_similarity_metrics, ref_other_metrics = categorize_metrics(ref_metrics)
    results = {}
    for benchmark in benchmarks:
        benchmark_results = {}
        dataset = get_dataset(benchmark, max_dataset_size=max_dataset_size, return_gt=True)
        dirname = f"{dataset.config_name}-{dataset._unchunk_size}"
        if dataset._chunk_start == 0 and dataset._chunk_step == 1:
            filename = f"{dirname}.npz"
        else:
            filename = os.path.join(dirname, f"{dataset._chunk_start}-{dataset._chunk_step}.npz")
            if chunk_only:
                dirname += f".{dataset._chunk_start}.{dataset._chunk_step}"
        gen_dirpath = os.path.join(gen_root, "samples", benchmark, dirname)
        if gt_metrics:
            gt_results = compute_image_multimodal_metrics(dataset, gen_dirpath, metrics=gt_multimodal_metrics)
            if "image_reward" in gt_other_metrics:
                gt_results.update(compute_image_reward(dataset, gen_dirpath))
            if benchmark in ("COCO", "DCI", "MJHQ"):
                gt_results.update(compute_image_similarity_metrics(dataset, gen_dirpath, metrics=gt_similarity_metrics))
                if "fid" in gt_other_metrics:
                    gt_results["fid"] = compute_fid(
                        dataset,
                        gen_dirpath,
                        ref_cache_path=(os.path.join(gt_stats_root, benchmark, filename) if gt_stats_root else None),
                        gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
                    )
            benchmark_results["with_gt"] = gt_results
        if ref_root and ref_metrics:
            assert os.path.exists(ref_root), f"Reference root directory {ref_root} does not exist."
            ref_dirpath = os.path.join(ref_root, "samples", benchmark, dirname)
            ref_results = compute_image_similarity_metrics(ref_dirpath, gen_dirpath, metrics=ref_similarity_metrics)
            if "fid" in ref_other_metrics:
                ref_results["fid"] = compute_fid(
                    ref_dirpath,
                    gen_dirpath,
                    ref_cache_path=os.path.join(ref_root, "fid_stats", benchmark, filename),
                    gen_cache_path=os.path.join(gen_root, "fid_stats", benchmark, filename),
                )
            benchmark_results["with_orig"] = ref_results
        print(f"{dirname} results:")
        print(benchmark_results)
        results[dirname] = benchmark_results
    return results


def categorize_metrics(metrics: tuple[str, ...]) -> tuple[list[str], list[str], list[str]]:
    """
    Categorize metrics into multimodal, similarity, and other metrics.

    Args:
        metrics (tuple[str, ...]): List of metrics.

    Returns:
        tuple[list[str], list[str], list[str]]: Tuple of multimodal, similarity, and other metrics.
    """
    metrics = tuple(set(metrics))
    multimodal_metrics, similarity_metrics, other_metrics = [], [], []
    for metric in metrics:
        if metric in ("clip_iqa", "clip_score"):
            multimodal_metrics.append(metric)
        elif metric in ("psnr", "lpips", "ssim"):
            similarity_metrics.append(metric)
        else:
            other_metrics.append(metric)
    return multimodal_metrics, similarity_metrics, other_metrics
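
`categorize_metrics` routes each requested metric into one of three dispatch buckets: multimodal metrics consume (image, prompt) pairs, similarity metrics consume (image, image) pairs, and everything else (e.g. `fid`, `image_reward`) is handled by its own entry point. A set-based sketch of the same routing:

```python
MULTIMODAL = {"clip_iqa", "clip_score"}  # scored on (image, prompt) pairs
SIMILARITY = {"psnr", "lpips", "ssim"}   # scored on (image, image) pairs


def categorize_metrics(metrics):
    # Same routing as above: deduplicate first, then anything unrecognized
    # (e.g. "fid", "image_reward") falls through to the "other" bucket.
    metrics = tuple(set(metrics))
    multimodal = [m for m in metrics if m in MULTIMODAL]
    similarity = [m for m in metrics if m in SIMILARITY]
    other = [m for m in metrics if m not in MULTIMODAL | SIMILARITY]
    return multimodal, similarity, other
```

Note that deduplication goes through `set`, so ordering within each bucket is not guaranteed.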


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/fid.py
================================================
import os
from datetime import datetime

import numpy as np
import torch
import torchvision
from cleanfid import fid
from cleanfid.resize import build_resizer
from datasets import Dataset
from tqdm import tqdm

__all__ = ["compute_fid"]


def get_dataset_features(
    dataset: Dataset,
    model,
    mode: str = "clean",
    batch_size: int = 128,
    device: str | torch.device = "cuda",
) -> np.ndarray:
    to_tensor = torchvision.transforms.ToTensor()
    fn_resize = build_resizer(mode)
    np_feats = []
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc=f"Extracting {dataset.config_name} features",
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        resized_images = [fn_resize(np.array(image.convert("RGB"))) for image in batch["image"]]
        image_tensors = []
        for resized_image in resized_images:
            if resized_image.dtype == "uint8":
                image_tensor = to_tensor(resized_image) * 255
            else:
                image_tensor = to_tensor(resized_image)
            image_tensors.append(image_tensor)
        image_tensors = torch.stack(image_tensors, dim=0)
        np_feats.append(fid.get_batch_features(image_tensors, model, device))
    np_feats = np.concatenate(np_feats, axis=0)
    return np_feats


def get_fid_features(
    dataset_or_folder: str | Dataset | None = None,
    cache_path: str | None = None,
    num: int | None = None,
    mode: str = "clean",
    num_workers: int = 8,
    batch_size: int = 64,
    device: str | torch.device = "cuda",
    force_overwrite: bool = False,
    verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
    if cache_path is not None and os.path.exists(cache_path) and not force_overwrite:
        npz = np.load(cache_path)
        mu, sigma = npz["mu"], npz["sigma"]
    else:
        feat_model = fid.build_feature_extractor(mode, device)
        if isinstance(dataset_or_folder, str):
            np_feats = fid.get_folder_features(
                dataset_or_folder,
                feat_model,
                num_workers=num_workers,
                num=num,
                batch_size=batch_size,
                device=device,
                verbose=verbose,
                mode=mode,
                description=f"Extracting {dataset_or_folder} features",
            )
        else:
            assert isinstance(dataset_or_folder, Dataset)
            np_feats = get_dataset_features(
                dataset_or_folder, model=feat_model, mode=mode, batch_size=batch_size, device=device
            )

        mu = np.mean(np_feats, axis=0)
        sigma = np.cov(np_feats, rowvar=False)
        if cache_path is not None:
            os.makedirs(os.path.abspath(os.path.dirname(cache_path)), exist_ok=True)
            np.savez(cache_path, mu=mu, sigma=sigma)

    return mu, sigma


def compute_fid(
    ref_dirpath_or_dataset: str | Dataset,
    gen_dirpath: str,
    ref_cache_path: str | None = None,
    gen_cache_path: str | None = None,
    use_symlink: bool = True,
    timestamp: str | None = None,
) -> float:
    sym_ref_dirpath, sym_gen_dirpath = None, None
    if use_symlink:
        if timestamp is None:
            timestamp = datetime.now().strftime("%y%m%d.%H%M%S")

        os.makedirs(".tmp", exist_ok=True)

        if isinstance(ref_dirpath_or_dataset, str):
            sym_ref_dirpath = os.path.join(".tmp", f"ref-{hash(str(ref_dirpath_or_dataset))}-{timestamp}")
            os.symlink(os.path.abspath(ref_dirpath_or_dataset), os.path.abspath(sym_ref_dirpath))
            ref_dirpath_or_dataset = sym_ref_dirpath

        sym_gen_dirpath = os.path.join(".tmp", f"gen-{hash(str(gen_dirpath))}-{timestamp}")
        os.symlink(os.path.abspath(gen_dirpath), os.path.abspath(sym_gen_dirpath))
        gen_dirpath = sym_gen_dirpath
    mu1, sigma1 = get_fid_features(dataset_or_folder=ref_dirpath_or_dataset, cache_path=ref_cache_path)
    mu2, sigma2 = get_fid_features(dataset_or_folder=gen_dirpath, cache_path=gen_cache_path)
    fid_score = fid.frechet_distance(mu1, sigma1, mu2, sigma2)
    fid_score = float(fid_score)
    if use_symlink:
        if sym_ref_dirpath is not None:
            os.remove(sym_ref_dirpath)
        os.remove(sym_gen_dirpath)
    return fid_score
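
`fid.frechet_distance` evaluates the closed-form squared 2-Wasserstein distance between the two feature Gaussians, d² = |μ₁ − μ₂|² + Tr(Σ₁ + Σ₂ − 2(Σ₁^½ Σ₂ Σ₁^½)^½). A numpy-only sketch of that formula, using symmetric eigendecompositions in place of `scipy.linalg.sqrtm` (a simplified illustration, not cleanfid's implementation):

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Closed-form squared 2-Wasserstein distance between N(mu1, sigma1)
    # and N(mu2, sigma2). Eigenvalues are clipped at zero to absorb
    # numerical noise in the PSD covariance estimates.
    diff = mu1 - mu2
    w, v = np.linalg.eigh(sigma1)  # sigma1 is symmetric PSD
    s1_half = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    inner = s1_half @ sigma2 @ s1_half  # symmetric PSD by construction
    w_inner = np.linalg.eigvalsh(inner)
    tr_sqrt = np.sum(np.sqrt(np.clip(w_inner, 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)
```

Identical distributions give a distance of zero, which is why caching `(mu, sigma)` per directory (as `get_fid_features` does) is sufficient to recompute FID against any other set of images.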


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/image_reward.py
================================================
import os

import datasets
import torch
from tqdm import tqdm

__all__ = ["compute_image_reward"]


def compute_image_reward(
    ref_dataset: datasets.Dataset,
    gen_dirpath: str,
) -> dict[str, float]:
    # import here to remove dependency on `ImageReward` git repo
    import ImageReward as RM

    scores = []
    model = RM.load("ImageReward-v1.0")
    for batch in tqdm(
        ref_dataset.iter(batch_size=1, drop_last_batch=False),
        desc=f"{ref_dataset.config_name} image reward",
        total=len(ref_dataset),
        dynamic_ncols=True,
    ):
        filename = batch["filename"][0]
        path = os.path.join(gen_dirpath, f"{filename}.png")
        prompt = batch["prompt"][0]
        with torch.inference_mode():
            score = model.score(prompt, path)
        scores.append(score)
    result = {"image_reward": sum(scores) / len(scores)}
    return result


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/multimodal.py
================================================
import os

import datasets
import numpy as np
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.multimodal import CLIPImageQualityAssessment, CLIPScore
from tqdm import tqdm

__all__ = ["compute_image_multimodal_metrics"]


class PromptImageDataset(data.Dataset):
    def __init__(self, ref_dataset: datasets.Dataset, gen_dirpath: str):
        super().__init__()
        self.ref_dataset, self.gen_dirpath = ref_dataset, gen_dirpath
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.ref_dataset)

    def __getitem__(self, idx: int):
        row = self.ref_dataset[idx]
        gen_image = Image.open(os.path.join(self.gen_dirpath, row["filename"] + ".png")).convert("RGB")
        gen_tensor = torch.from_numpy(np.array(gen_image)).permute(2, 0, 1)
        prompt = row["prompt"]
        return [gen_tensor, prompt]


def compute_image_multimodal_metrics(
    ref_dataset: datasets.Dataset,
    gen_dirpath: str,
    metrics: tuple[str, ...] = ("clip_iqa", "clip_score"),
    batch_size: int = 64,
    num_workers: int = 8,
    device: str | torch.device = "cuda",
) -> dict[str, float]:
    if len(metrics) == 0:
        return {}
    metric_names = metrics
    metrics: dict[str, torchmetrics.Metric] = {}
    for metric_name in metric_names:
        if metric_name == "clip_iqa":
            metric = CLIPImageQualityAssessment(model_name_or_path="openai/clip-vit-large-patch14").to(device)
        elif metric_name == "clip_score":
            metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
        else:
            raise NotImplementedError(f"Metric {metric_name} is not implemented")
        metrics[metric_name] = metric
    dataset = PromptImageDataset(ref_dataset, gen_dirpath)
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
    )
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"{ref_dataset.config_name} multimodal metrics"):
            batch[0] = batch[0].to(device)
            for metric_name, metric in metrics.items():
                if metric_name == "clip_iqa":
                    metric.update(batch[0].to(torch.float32))
                else:
                    prompts = list(batch[1])
                    metric.update(batch[0], prompts)
    result = {metric_name: metric.compute().mean().item() for metric_name, metric in metrics.items()}
    return result
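
For intuition, the CLIPScore reported above is, per torchmetrics' definition, the cosine similarity between the CLIP image and text embeddings, scaled by 100 and clamped at zero. A pure-Python sketch over precomputed embedding vectors (hypothetical inputs, not the actual CLIP model):

```python
import math


def clip_score(image_emb: list[float], text_emb: list[float]) -> float:
    """Scaled, zero-clamped cosine similarity between two embedding vectors."""
    dot = sum(a * b for a, b in zip(image_emb, text_emb))
    norm = math.sqrt(sum(a * a for a in image_emb)) * math.sqrt(sum(b * b for b in text_emb))
    return max(100.0 * dot / norm, 0.0)


# identical (unit-norm) embeddings give the maximal score of 100
print(clip_score([0.6, 0.8], [0.6, 0.8]))  # → 100.0
```

The real metric additionally runs the CLIP encoders to produce the embeddings; this sketch only shows the final reduction.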


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/run.py
================================================
# -*- coding: utf-8 -*-
"""Evaluate generated images or videos using the specified metrics."""

import json
import os

from ...config import DiffusionPtqRunConfig

if __name__ == "__main__":
    config, _, unused_cfgs, unused_args, unknown_args = DiffusionPtqRunConfig.get_parser().parse_known_args()
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"
    assert len(unused_cfgs) == 0, f"Unused configurations: {unused_cfgs}"
    assert unused_args is None, f"Unused arguments: {unused_args}"
    assert isinstance(config, DiffusionPtqRunConfig)
    results = config.eval.evaluate(pipeline=None, skip_gen=True, task=config.pipeline.task)
    save_path = os.path.join(config.eval.gen_root, f"results-{config.output.timestamp}.json")
    os.makedirs(os.path.abspath(os.path.dirname(save_path)), exist_ok=True)
    with open(save_path, "w") as f:
        json.dump(results, f, indent=2, sort_keys=True)
    print(results)


================================================
FILE: deepcompressor/app/diffusion/eval/metrics/similarity.py
================================================
import os

import datasets
import torch
import torchmetrics
import torchvision
from PIL import Image
from torch.utils import data
from torchmetrics.image import (
    LearnedPerceptualImagePatchSimilarity,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)
from tqdm import tqdm

__all__ = ["compute_image_similarity_metrics"]


class MultiImageDataset(data.Dataset):
    def __init__(self, gen_dirpath: str, ref_dirpath_or_dataset: str | datasets.Dataset):
        super().__init__()
        self.gen_names = sorted(
            [name for name in os.listdir(gen_dirpath) if name.endswith(".png") or name.endswith(".jpg")]
        )
        self.gen_dirpath, self.ref_dirpath_or_dataset = gen_dirpath, ref_dirpath_or_dataset
        if isinstance(ref_dirpath_or_dataset, str):
            self.ref_names = sorted(
                [name for name in os.listdir(ref_dirpath_or_dataset) if name.endswith(".png") or name.endswith(".jpg")]
            )
            assert len(self.ref_names) == len(self.gen_names)
        else:
            assert isinstance(ref_dirpath_or_dataset, datasets.Dataset)
            self.ref_names = self.gen_names
            assert len(ref_dirpath_or_dataset) == len(self.gen_names)
        self.transform = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.ref_names)

    def __getitem__(self, idx: int):
        if isinstance(self.ref_dirpath_or_dataset, str):
            name = self.ref_names[idx]
            assert name == self.gen_names[idx]
            ref_image = Image.open(os.path.join(self.ref_dirpath_or_dataset, name)).convert("RGB")
        else:
            row = self.ref_dirpath_or_dataset[idx]
            ref_image = row["image"].convert("RGB")
            name = row["filename"] + ".png"
        gen_image = Image.open(os.path.join(self.gen_dirpath, name)).convert("RGB")
        gen_size = gen_image.size
        ref_size = ref_image.size
        if ref_size != gen_size:
            ref_image = ref_image.resize(gen_size, Image.Resampling.BICUBIC)
        gen_tensor = self.transform(gen_image)
        ref_tensor = self.transform(ref_image)
        return [gen_tensor, ref_tensor]


def compute_image_similarity_metrics(
    ref_dirpath_or_dataset: str | datasets.Dataset,
    gen_dirpath: str,
    metrics: tuple[str, ...] = ("psnr", "lpips", "ssim"),
    batch_size: int = 64,
    num_workers: int = 8,
    device: str | torch.device = "cuda",
) -> dict[str, float]:
    if len(metrics) == 0:
        return {}
    metric_names = metrics
    metrics: dict[str, torchmetrics.Metric] = {}
    for metric_name in metric_names:
        if metric_name == "psnr":
            metric = PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3)).to(device)
        elif metric_name == "lpips":
            metric = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)
        elif metric_name == "ssim":
            metric = StructuralSimilarityIndexMeasure(data_range=(0, 1)).to(device)
        else:
            raise NotImplementedError(f"Metric {metric_name} is not implemented")
        metrics[metric_name] = metric
    dataset = MultiImageDataset(gen_dirpath, ref_dirpath_or_dataset)
    dataloader = data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, drop_last=False
    )
    with torch.no_grad():
        desc = (
            ref_dirpath_or_dataset.config_name
            if isinstance(ref_dirpath_or_dataset, datasets.Dataset)
            else os.path.basename(ref_dirpath_or_dataset)
        ) + " similarity metrics"
        for batch in tqdm(dataloader, desc=desc):
            batch = [tensor.to(device) for tensor in batch]
            for metric in metrics.values():
                metric.update(batch[0], batch[1])
    result = {metric_name: metric.compute().item() for metric_name, metric in metrics.items()}
    return result
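
As configured above, `PeakSignalNoiseRatio(data_range=(0, 1), reduction="elementwise_mean", dim=(1, 2, 3))` computes one PSNR per image (MSE over that image's channels and pixels) and then averages over the batch. The reduction can be sketched in pure Python over flattened per-image pixel lists (illustrative only, not the torchmetrics implementation):

```python
import math


def batch_psnr(gen: list[list[float]], ref: list[list[float]], data_range: float = 1.0) -> float:
    """Per-image PSNR from the per-image MSE, then an elementwise mean over the batch."""
    psnrs = []
    for g, r in zip(gen, ref):
        mse = sum((a - b) ** 2 for a, b in zip(g, r)) / len(g)
        psnrs.append(10.0 * math.log10(data_range**2 / mse))
    return sum(psnrs) / len(psnrs)


# one image whose pixels are all off by 0.5 -> 10 * log10(1 / 0.25) ≈ 6.02 dB
print(round(batch_psnr([[0.5, 0.5]], [[0.0, 0.0]]), 2))  # → 6.02
```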


================================================
FILE: deepcompressor/app/diffusion/nn/__init__.py
================================================
# -*- coding: utf-8 -*-


================================================
FILE: deepcompressor/app/diffusion/nn/attention.py
================================================
# -*- coding: utf-8 -*-

import typing as tp

import diffusers
import packaging.version
import torch
import torch.nn as nn
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor2_0,
    FluxAttnProcessor2_0,
    JointAttnProcessor2_0,
)

from deepcompressor.nn.patch.sdpa import ScaleDotProductAttention

__all__ = ["DiffusionAttentionProcessor"]


if packaging.version.Version(diffusers.__version__) >= packaging.version.Version("0.31"):
    from diffusers.models.embeddings import apply_rotary_emb

    def apply_flux_rope(query, key, image_rotary_emb):
        query = apply_rotary_emb(query, image_rotary_emb)
        key = apply_rotary_emb(key, image_rotary_emb)
        return query, key

else:
    from diffusers.models.attention_processor import apply_rope as apply_flux_rope
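
Both branches of this version gate apply a rotary position embedding to query and key. Independent of diffusers' pairing convention (interleaved vs. split halves), the core operation rotates each 2-D feature pair by a position-dependent angle; a minimal pure-Python sketch of one pair (schematic, not diffusers' exact layout):

```python
import math


def rotate_pair(x1: float, x2: float, theta: float) -> tuple[float, float]:
    """Rotate one (x1, x2) feature pair by angle theta, as in rotary embeddings."""
    cos, sin = math.cos(theta), math.sin(theta)
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos


# a quarter-turn maps (1, 0) to (0, 1)
x1, x2 = rotate_pair(1.0, 0.0, math.pi / 2)
print(round(x1, 6), round(x2, 6))  # → 0.0 1.0
```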


class DiffusionAttentionProcessor(nn.Module):
    def __init__(
        self,
        orig: AttnProcessor2_0 | FluxAttnProcessor2_0 | JointAttnProcessor2_0,
        sdpa: ScaleDotProductAttention | None = None,
    ) -> None:
        super().__init__()
        self.orig = orig
        if orig.__class__.__name__.startswith("Flux"):
            self.rope = apply_flux_rope
        elif isinstance(orig, (AttnProcessor2_0, JointAttnProcessor2_0)):
            self.rope = None
        else:
            raise NotImplementedError(f"Unsupported AttentionProcessor: {orig}")
        self.sdpa = sdpa or ScaleDotProductAttention()

    def __call__(  # noqa: C901
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: tp.Optional[torch.Tensor] = None,
        attention_mask: tp.Optional[torch.Tensor] = None,
        image_rotary_emb: tp.Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert len(args) == 0 and kwargs.get("scale", None) is None
        assert attn.spatial_norm is None
        assert attn.group_norm is None
        assert attn.norm_cross is None
        assert not attn.residual_connection
        assert attn.rescale_output_factor == 1.0
        heads = attn.heads
        head_dim = attn.inner_dim // heads
        kv_heads = attn.inner_kv_dim // head_dim
        assert attn.scale == head_dim**-0.5

        input_ndim, input_shape = hidden_states.dim(), hidden_states.size()
        if input_ndim > 3:
            hidden_states = hidden_states.view(input_shape[0], input_shape[1], -1).transpose(1, 2)
        batch_size, input_length, _ = hidden_states.shape
        context_ndim, context_shape, context_length = None, None, None
        if encoder_hidden_states is not None:
            context_ndim, context_shape = encoder_hidden_states.ndim, encoder_hidden_states.shape
            assert context_shape[0] == batch_size
            if context_ndim > 3:
                encoder_hidden_states = encoder_hidden_states.view(batch_size, context_shape[1], -1).transpose(1, 2)
            context_length = encoder_hidden_states.shape[1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, context_length or input_length, batch_size)
            attention_mask = attention_mask.view(batch_size, heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key, value, add_query, add_key, add_value = None, None, None, None, None
        if hasattr(attn, "add_k_proj"):
            if attn.to_k is not None:
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
            add_key = attn.add_k_proj(encoder_hidden_states)
            add_value = attn.add_v_proj(encoder_hidden_states)
            if hasattr(attn, "add_q_proj"):
                add_query = attn.add_q_proj(encoder_hidden_states)
        else:
            if attn.is_cross_attention:
                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)
            else:
                assert encoder_hidden_states is None
                key = attn.to_k(hidden_states)
                value = attn.to_v(hidden_states)
        hidden_states, encoder_hidden_states = None, None

        query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        if key is not None:
            key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
        if add_query is not None:
            add_query = add_query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
        if add_key is not None:
            add_key = add_key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
            add_value = add_value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)

        if kv_heads != heads:
            heads_per_kv_head = heads // kv_heads
            if key is not None:
                key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
                value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
            if add_key is not None:
                add_key = torch.repeat_interleave(add_key, heads_per_kv_head, dim=1)
                add_value = torch.repeat_interleave(add_value, heads_per_kv_head, dim=1)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
            key = attn.norm_k(key)
        if attn.norm_added_q is not None:
            add_query = attn.norm_added_q(add_query)
            add_key = attn.norm_added_k(add_key)

        if add_query is not None:
            query = torch.cat([add_query, query], dim=2)
        if add_key is not None:
            if key is None:
                key, value = add_key, add_value
            else:
                key = torch.cat([add_key, key], dim=2)
                value = torch.cat([add_value, value], dim=2)
        del add_query, add_key, add_value

        if image_rotary_emb is not None:
            query, key = self.rope(query, key, image_rotary_emb)

        hidden_states = self.sdpa(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.inner_dim)
        hidden_states = hidden_states.to(query.dtype)

        if hidden_states.shape[1] > input_length:
            encoder_hidden_states = hidden_states[:, :context_length]
            hidden_states = hidden_states[:, context_length:]

        if hasattr(attn, "to_out"):
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim > 3:
            hidden_states = hidden_states.transpose(-1, -2).reshape(input_shape)
        if encoder_hidden_states is not None and context_ndim > 3:
            assert encoder_hidden_states.ndim == 3
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(context_shape)

        if encoder_hidden_states is None:
            return hidden_states
        return hidden_states, encoder_hidden_states
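
When `kv_heads != heads`, the processor above broadcasts each key/value head to several query heads via `torch.repeat_interleave(..., dim=1)` (the grouped-query-attention pattern). Over a plain list of per-head items the expansion looks like this (schematic):

```python
def repeat_kv_heads(kv: list[str], heads: int) -> list[str]:
    """Repeat each KV head contiguously so every query head has a matching KV head."""
    kv_heads = len(kv)
    assert heads % kv_heads == 0
    heads_per_kv_head = heads // kv_heads
    return [h for h in kv for _ in range(heads_per_kv_head)]


print(repeat_kv_heads(["k0", "k1"], heads=4))  # → ['k0', 'k0', 'k1', 'k1']
```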


================================================
FILE: deepcompressor/app/diffusion/nn/patch.py
================================================
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.utils import patch, tools

from .attention import DiffusionAttentionProcessor
from .struct import DiffusionFeedForwardStruct, DiffusionModelStruct, DiffusionResnetStruct, UNetStruct

__all__ = [
    "replace_up_block_conv_with_concat_conv",
    "replace_fused_linear_with_concat_linear",
    "replace_attn_processor",
    "shift_input_activations",
]


def replace_up_block_conv_with_concat_conv(model: nn.Module) -> None:
    """Replace up_block convolutions in UNet with ConcatConv."""
    model_struct = DiffusionModelStruct.construct(model)
    if not isinstance(model_struct, UNetStruct):
        return
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing up_block convolutions with ConcatConv.")
    tools.logging.Formatter.indent_inc()
    parents_map = patch.get_module_parents_map(model)
    for up_block in model_struct.up_block_structs:
        logger.info(f"+ Replacing convolutions in up_block {up_block.name}")
        tools.logging.Formatter.indent_inc()
        for resnet in up_block.resnet_structs:
            assert len(resnet.convs[0]) == 1
            conv, conv_name = resnet.convs[0][0], resnet.conv_names[0][0]
            logger.info(f"- Replacing {conv_name} in resnet {resnet.name}")
            tools.logging.Formatter.indent_inc()
            if resnet.idx == 0:
                if up_block.idx == 0:
                    prev_block = model_struct.mid_block_struct
                else:
                    prev_block = model_struct.up_block_structs[up_block.idx - 1]
                logger.info(f"+ using previous block {prev_block.name}")
                prev_channels = prev_block.resnet_structs[-1].convs[-1][-1].out_channels
            else:
                prev_channels = up_block.resnet_structs[resnet.idx - 1].convs[-1][-1].out_channels
            logger.info(f"+ conv_in_channels = {prev_channels}/{conv.in_channels}")
            logger.info(f"+ conv_out_channels = {conv.out_channels}")
            concat_conv = ConcatConv2d.from_conv2d(conv, [prev_channels])
            for parent_name, parent_module, child_name in parents_map[conv]:
                logger.info(f"+ replacing {child_name} in {parent_name}")
                setattr(parent_module, child_name, concat_conv)
            tools.logging.Formatter.indent_dec()
        tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()


def replace_fused_linear_with_concat_linear(model: nn.Module) -> None:
    """Replace fused Linear in FluxSingleTransformerBlock with ConcatLinear."""
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing fused Linear with ConcatLinear.")
    tools.logging.Formatter.indent_inc()
    for name, module in model.named_modules():
        if isinstance(module, FluxSingleTransformerBlock):
            logger.info(f"+ Replacing fused Linear in {name} with ConcatLinear.")
            tools.logging.Formatter.indent_inc()
            logger.info(f"- in_features = {module.proj_out.out_features}/{module.proj_out.in_features}")
            logger.info(f"- out_features = {module.proj_out.out_features}")
            tools.logging.Formatter.indent_dec()
            module.proj_out = ConcatLinear.from_linear(module.proj_out, [module.proj_out.out_features])
    tools.logging.Formatter.indent_dec()


def shift_input_activations(model: nn.Module) -> None:
    """Shift input activations of convolutions and linear layers if their lowerbound is negative.

    Args:
        model (nn.Module): model to shift input activations.
    """
    logger = tools.logging.getLogger(__name__)
    model_struct = DiffusionModelStruct.construct(model)
    module_parents_map = patch.get_module_parents_map(model)
    logger.info("- Shifting input activations.")
    tools.logging.Formatter.indent_inc()
    for _, module_name, module, parent, field_name in model_struct.named_key_modules():
        lowerbound = None
        if isinstance(parent, DiffusionResnetStruct) and field_name.startswith("conv"):
            lowerbound = parent.config.intermediate_lowerbound
        elif isinstance(parent, DiffusionFeedForwardStruct) and field_name.startswith("down_proj"):
            lowerbound = parent.config.intermediate_lowerbound
        if lowerbound is not None and lowerbound < 0:
            shift = -lowerbound
            logger.info(f"+ Shifting input activations of {module_name} by {shift}")
            tools.logging.Formatter.indent_inc()
            if isinstance(module, nn.Linear):
                shifted = ShiftedLinear.from_linear(module, shift=shift)
                shifted.linear.unsigned = True
            elif isinstance(module, nn.Conv2d):
                shifted = ShiftedConv2d.from_conv2d(module, shift=shift)
                shifted.conv.unsigned = True
            else:
                raise NotImplementedError(f"Unsupported module type {type(module)}")
            for parent_name, parent_module, child_name in module_parents_map[module]:
                logger.info(f"+ Replacing {child_name} in {parent_name}")
                setattr(parent_module, child_name, shifted)
            tools.logging.Formatter.indent_dec()
    tools.logging.Formatter.indent_dec()
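
The shift is output-preserving: adding a constant `shift` to the input makes it non-negative (suitable for unsigned quantization), and the bias absorbs the difference, since W(x + s·1) + (b − W·s·1) = Wx + b. A numeric sketch of that identity (assumed equivalence, not the actual `ShiftedLinear`/`ShiftedConv2d` classes):

```python
def linear(x: list[float], w: list[list[float]], b: list[float]) -> list[float]:
    """Plain affine map: y = Wx + b."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(w, b)]


def shifted_params(w: list[list[float]], b: list[float], shift: float):
    """Fold a constant input shift into the bias so outputs are unchanged."""
    return w, [bi - shift * sum(row) for row, bi in zip(w, b)]


w, b, shift = [[1.0, 2.0]], [0.5], 3.0
w2, b2 = shifted_params(w, b, shift)
x = [-3.0, 1.0]  # negative inputs become non-negative after shifting by 3
assert linear([xi + shift for xi in x], w2, b2) == linear(x, w, b)
```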


def replace_attn_processor(model: nn.Module) -> None:
    """Replace Attention processor with DiffusionAttentionProcessor."""
    logger = tools.logging.getLogger(__name__)
    logger.info("Replacing Attention processors.")
    tools.logging.Formatter.indent_inc()
    for name, module in model.named_modules():
        if isinstance(module, Attention):
            logger.info(f"+ Replacing {name} processor with DiffusionAttentionProcessor.")
            module.set_processor(DiffusionAttentionProcessor(module.processor))
    tools.logging.Formatter.indent_dec()


================================================
FILE: deepcompressor/app/diffusion/nn/struct.py
================================================
# -*- coding: utf-8 -*-
"""Utility functions for Diffusion Models."""

import enum
import typing as tp
from abc import abstractmethod
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field

# region imports
import torch.nn as nn
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU
from diffusers.models.attention import BasicTransformerBlock, FeedForward, JointTransformerBlock
from diffusers.models.attention_processor import Attention, SanaLinearAttnProcessor2_0
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    PatchEmbed,
    PixArtAlphaTextProjection,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
)
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
from diffusers.models.transformers.sana_transformer import GLUMBConv, SanaTransformer2DModel, SanaTransformerBlock
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    FluxTransformer2DModel,
    FluxTransformerBlock,
)
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.models.unets.unet_2d import UNet2DModel
from diffusers.models.unets.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.pipelines import (
    FluxControlPipeline,
    FluxFillPipeline,
    FluxPipeline,
    PixArtAlphaPipeline,
    PixArtSigmaPipeline,
    SanaPipeline,
    StableDiffusion3Pipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)

from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.nn.struct.attn import (
    AttentionConfigStruct,
    AttentionStruct,
    BaseTransformerStruct,
    FeedForwardConfigStruct,
    FeedForwardStruct,
    TransformerBlockStruct,
)
from deepcompressor.nn.struct.base import BaseModuleStruct
from deepcompressor.utils.common import join_name

from .attention import DiffusionAttentionProcessor

# endregion


__all__ = ["DiffusionModelStruct", "DiffusionBlockStruct", "DiffusionModelStruct"]


DIT_BLOCK_CLS = tp.Union[
    BasicTransformerBlock,
    JointTransformerBlock,
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
    SanaTransformerBlock,
]
UNET_BLOCK_CLS = tp.Union[
    DownBlock2D,
    CrossAttnDownBlock2D,
    UNetMidBlock2D,
    UNetMidBlock2DCrossAttn,
    UpBlock2D,
    CrossAttnUpBlock2D,
]
DIT_CLS = tp.Union[
    Transformer2DModel,
    PixArtTransformer2DModel,
    SD3Transformer2DModel,
    FluxTransformer2DModel,
    SanaTransformer2DModel,
]
UNET_CLS = tp.Union[UNet2DModel, UNet2DConditionModel]
MODEL_CLS = tp.Union[DIT_CLS, UNET_CLS]
UNET_PIPELINE_CLS = tp.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
DIT_PIPELINE_CLS = tp.Union[
    StableDiffusion3Pipeline,
    PixArtAlphaPipeline,
    PixArtSigmaPipeline,
    FluxPipeline,
    FluxControlPipeline,
    FluxFillPipeline,
    SanaPipeline,
]
PIPELINE_CLS = tp.Union[UNET_PIPELINE_CLS, DIT_PIPELINE_CLS]


@dataclass(kw_only=True)
class DiffusionModuleStruct(BaseModuleStruct):
    def named_key_modules(self) -> tp.Generator[tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        if isinstance(self.module, (nn.Linear, nn.Conv2d)):
            yield self.key, self.name, self.module, self.parent, self.fname
        else:
            for name, module in self.module.named_modules():
                if name and isinstance(module, (nn.Linear, nn.Conv2d)):
                    module_name = join_name(self.name, name)
                    field_name = join_name(self.fname, name)
                    yield self.key, module_name, module, self.parent, field_name


@dataclass(kw_only=True)
class DiffusionBlockStruct(BaseModuleStruct):
    @abstractmethod
    def iter_attention_structs(self) -> tp.Generator["DiffusionAttentionStruct", None, None]: ...

    @abstractmethod
    def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]: ...


@dataclass(kw_only=True)
class DiffusionModelStruct(DiffusionBlockStruct):
    pre_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)
    post_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)

    @property
    @abstractmethod
    def num_blocks(self) -> int: ...

    @property
    @abstractmethod
    def block_structs(self) -> list[DiffusionBlockStruct]: ...

    @abstractmethod
    def get_prev_module_keys(self) -> tuple[str, ...]: ...

    @abstractmethod
    def get_post_module_keys(self) -> tuple[str, ...]: ...

    @abstractmethod
    def _get_iter_block_activations_args(
        self, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]: ...

    def _get_iter_pre_module_activations_args(
        self,
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
        layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
        for layer_struct in self.pre_module_structs.values():
            layers.append(layer_struct.module)
            layer_structs.append(layer_struct)
            recomputes.append(False)
            use_prev_layer_outputs.append(False)
        return layers, layer_structs, recomputes, use_prev_layer_outputs

    def _get_iter_post_module_activations_args(
        self,
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
        layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
        for layer_struct in self.post_module_structs.values():
            layers.append(layer_struct.module)
            layer_structs.append(layer_struct)
            recomputes.append(False)
            use_prev_layer_outputs.append(False)
        return layers, layer_structs, recomputes, use_prev_layer_outputs

    def get_iter_layer_activations_args(
        self, skip_pre_modules: bool, skip_post_modules: bool, **input_kwargs
    ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
        """
        Get the arguments for iterating over the layers and their activations.

        Args:
            skip_pre_modules (`bool`):
                Whether to skip the pre-modules
            skip_post_modules (`bool`):
                Whether to skip the post-modules

        Returns:
            `tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]`:
                the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
        """
        layers, structs, recomputes, uses = [], [], [], []
        if not skip_pre_modules:
            layers, structs, recomputes, uses = self._get_iter_pre_module_activations_args()
        _layers, _structs, _recomputes, _uses = self._get_iter_block_activations_args(**input_kwargs)
        layers.extend(_layers)
        structs.extend(_structs)
        recomputes.extend(_recomputes)
        uses.extend(_uses)
        if not skip_post_modules:
            _layers, _structs, _recomputes, _uses = self._get_iter_post_module_activations_args()
            layers.extend(_layers)
            structs.extend(_structs)
            recomputes.extend(_recomputes)
            uses.extend(_uses)
        return layers, structs, recomputes, uses

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        for module in self.pre_module_structs.values():
            yield from module.named_key_modules()
        for block in self.block_structs:
            yield from block.named_key_modules()
        for module in self.post_module_structs.values():
            yield from module.named_key_modules()

    def iter_attention_structs(self) -> tp.Generator["AttentionStruct", None, None]:
        for block in self.block_structs:
            yield from block.iter_attention_structs()

    def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]:
        for block in self.block_structs:
            yield from block.iter_transformer_block_structs()

    def get_named_layers(
        self, skip_pre_modules: bool, skip_post_modules: bool, skip_blocks: bool = False
    ) -> OrderedDict[str, DiffusionBlockStruct | DiffusionModuleStruct]:
        named_layers = OrderedDict()
        if not skip_pre_modules:
            named_layers.update(self.pre_module_structs)
        if not skip_blocks:
            for block in self.block_structs:
                named_layers[block.name] = block
        if not skip_post_modules:
            named_layers.update(self.post_module_structs)
        return named_layers

    @staticmethod
    def _default_construct(
        module: tp.Union[PIPELINE_CLS, MODEL_CLS],
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionModelStruct":
        if isinstance(module, UNET_PIPELINE_CLS):
            module = module.unet
        elif isinstance(module, DIT_PIPELINE_CLS):
            module = module.transformer
        if isinstance(module, UNET_CLS):
            return UNetStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
        elif isinstance(module, DIT_CLS):
            return DiTStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
        raise NotImplementedError(f"Unsupported module type: {type(module)}")

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        unet_key_map = UNetStruct._get_default_key_map()
        dit_key_map = DiTStruct._get_default_key_map()
        flux_key_map = FluxStruct._get_default_key_map()
        key_map: dict[str, set[str]] = defaultdict(set)
        for rkey, keys in unet_key_map.items():
            key_map[rkey].update(keys)
        for rkey, keys in dit_key_map.items():
            key_map[rkey].update(keys)
        for rkey, keys in flux_key_map.items():
            key_map[rkey].update(keys)
        return {k: v for k, v in key_map.items() if v}

    @staticmethod
    def _simplify_keys(keys: tp.Iterable[str], *, key_map: dict[str, set[str]]) -> list[str]:
        """Simplify the keys based on the key map.

        Args:
            keys (`Iterable[str]`):
                The keys to simplify.
            key_map (`dict[str, set[str]]`):
                The key map.

        Returns:
            `list[str]`:
                The simplified keys.
        """
        # we first sort key_map by length of values in descending order
        key_map = dict(sorted(key_map.items(), key=lambda item: len(item[1]), reverse=True))
        ukeys, skeys = set(keys), set()
        for k, v in key_map.items():
            if k in ukeys:
                skeys.add(k)
                ukeys.discard(k)
                ukeys.difference_update(v)
                continue
            if ukeys.issuperset(v):
                skeys.add(k)
                ukeys.difference_update(v)
        assert not ukeys, f"Unrecognized keys: {ukeys}"
        return sorted(skeys)
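
`_simplify_keys` greedily collapses a key set: a composite key replaces its members whenever the key itself, or its full expansion, is present, starting from the largest expansions. A standalone pure-Python sketch of the same greedy pass (hypothetical key names):

```python
def simplify_keys(keys: set[str], key_map: dict[str, set[str]]) -> list[str]:
    """Greedily replace groups of keys by their composite key, largest groups first."""
    key_map = dict(sorted(key_map.items(), key=lambda item: len(item[1]), reverse=True))
    ukeys, skeys = set(keys), set()
    for k, v in key_map.items():
        if k in ukeys:
            skeys.add(k)
            ukeys.discard(k)
            ukeys.difference_update(v)
        elif ukeys.issuperset(v):
            skeys.add(k)
            ukeys.difference_update(v)
    assert not ukeys, f"Unrecognized keys: {ukeys}"
    return sorted(skeys)


key_map = {"qkv_proj": {"q_proj", "k_proj", "v_proj"}, "out_proj": set()}
print(simplify_keys({"q_proj", "k_proj", "v_proj", "out_proj"}, key_map))
# → ['out_proj', 'qkv_proj']
```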


@dataclass(kw_only=True)
class DiffusionAttentionStruct(AttentionStruct):
    module: Attention = field(repr=False, kw_only=False)
    """the module of AttentionBlock"""
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)

    def filter_kwargs(self, kwargs: dict) -> dict:
        """Filter layer kwargs to attn kwargs."""
        if isinstance(self.parent.module, BasicTransformerBlock):
            if kwargs.get("cross_attention_kwargs", None) is None:
                attn_kwargs = {}
            else:
                attn_kwargs = dict(kwargs["cross_attention_kwargs"].items())
            attn_kwargs.pop("gligen", None)
            if self.idx == 0:
                # self-attention (attn1) takes the plain attention mask
                attn_kwargs["attention_mask"] = kwargs.get("attention_mask", None)
            else:
                # cross-attention (attn2) takes the encoder attention mask
                attn_kwargs["attention_mask"] = kwargs.get("encoder_attention_mask", None)
        else:
            attn_kwargs = {}
        return attn_kwargs

    @staticmethod
    def _default_construct(
        module: Attention,
        /,
        parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionAttentionStruct":
        if module.is_cross_attention:
            # in cross-attention, to_k/to_v consume the encoder hidden states,
            # so they are recorded as the "added" (context) projections
            q_proj, k_proj, v_proj = module.to_q, None, None
            add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
        else:
            q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
            add_q_proj = getattr(module, "add_q_proj", None)
            add_k_proj = getattr(module, "add_k_proj", None)
            add_v_proj = getattr(module, "add_v_proj", None)
            add_o_proj = getattr(module, "to_add_out", None)
            q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
            add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
            add_o_proj_rname = "to_add_out"
        if getattr(module, "to_out", None) is not None:
            o_proj = module.to_out[0]
            o_proj_rname = "to_out.0"
            assert isinstance(o_proj, nn.Linear)
        elif parent is not None:
            assert isinstance(parent.module, FluxSingleTransformerBlock)
            assert isinstance(parent.module.proj_out, ConcatLinear)
            assert len(parent.module.proj_out.linears) == 2
            o_proj = parent.module.proj_out.linears[0]
            o_proj_rname = ".proj_out.linears.0"
        else:
            raise RuntimeError("Cannot find the output projection.")
        if isinstance(module.processor, DiffusionAttentionProcessor):
            with_rope = module.processor.rope is not None
        elif module.processor.__class__.__name__.startswith("Flux"):
            with_rope = True
        else:
            with_rope = False  # TODO: fix for other processors
        config = AttentionConfigStruct(
            hidden_size=q_proj.weight.shape[1],
            add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
            inner_size=q_proj.weight.shape[0],
            num_query_heads=module.heads,
            # head_dim = to_q out_features // heads; kv heads = to_k out_features // head_dim
            num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
            with_qk_norm=module.norm_q is not None,
            with_rope=with_rope,
            linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
        )
        return DiffusionAttentionStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            q_proj=q_proj,
            k_proj=k_proj,
            v_proj=v_proj,
            o_proj=o_proj,
            add_q_proj=add_q_proj,
            add_k_proj=add_k_proj,
            add_v_proj=add_v_proj,
            add_o_proj=add_o_proj,
            q=None,  # TODO: add q, k, v
            k=None,
            v=None,
            q_proj_rname=q_proj_rname,
            k_proj_rname=k_proj_rname,
            v_proj_rname=v_proj_rname,
            o_proj_rname=o_proj_rname,
            add_q_proj_rname=add_q_proj_rname,
            add_k_proj_rname=add_k_proj_rname,
            add_v_proj_rname=add_v_proj_rname,
            add_o_proj_rname=add_o_proj_rname,
            q_rname="",
            k_rname="",
            v_rname="",
        )


@dataclass(kw_only=True)
class DiffusionFeedForwardStruct(FeedForwardStruct):
    module: FeedForward = field(repr=False, kw_only=False)
    """the FeedForward module"""
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)
    # region modules
    moe_gate: None = field(init=False, repr=False, default=None)
    experts: list[nn.Module] = field(init=False, repr=False)
    # endregion
    # region names
    moe_gate_rname: str = field(init=False, repr=False, default="")
    experts_rname: str = field(init=False, repr=False, default="")
    # endregion

    # region aliases

    @property
    def up_proj(self) -> nn.Linear:
        return self.up_projs[0]

    @property
    def down_proj(self) -> nn.Linear:
        return self.down_projs[0]

    @property
    def up_proj_rname(self) -> str:
        return self.up_proj_rnames[0]

    @property
    def down_proj_rname(self) -> str:
        return self.down_proj_rnames[0]

    @property
    def up_proj_name(self) -> str:
        return self.up_proj_names[0]

    @property
    def down_proj_name(self) -> str:
        return self.down_proj_names[0]

    # endregion

    def __post_init__(self) -> None:
        assert len(self.up_projs) == len(self.down_projs) == 1
        assert len(self.up_proj_rnames) == len(self.down_proj_rnames) == 1
        self.experts = [self.module]
        super().__post_init__()

    @staticmethod
    def _default_construct(
        module: FeedForward | FluxSingleTransformerBlock | GLUMBConv,
        /,
        parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionFeedForwardStruct":
        if isinstance(module, FeedForward):
            layer_1, layer_2 = module.net[0], module.net[2]
            assert isinstance(layer_1, (GEGLU, GELU, ApproximateGELU, SwiGLU))
            up_proj, up_proj_rname = layer_1.proj, "net.0.proj"
            assert isinstance(up_proj, nn.Linear)
            down_proj, down_proj_rname = layer_2, "net.2"
            if isinstance(layer_1, GEGLU):
                act_type = "gelu_glu"
            elif isinstance(layer_1, SwiGLU):
                act_type = "swish_glu"
            else:
                assert layer_1.__class__.__name__.lower().endswith("gelu")
                act_type = "gelu"
                if isinstance(layer_2, ShiftedLinear):
                    down_proj, down_proj_rname = layer_2.linear, "net.2.linear"
                    act_type = "gelu_shifted"
            assert isinstance(down_proj, nn.Linear)
            ffn = module
        elif isinstance(module, FluxSingleTransformerBlock):
            up_proj, up_proj_rname = module.proj_mlp, "proj_mlp"
            act_type = "gelu"
            assert isinstance(module.proj_out, ConcatLinear)
            assert len(module.proj_out.linears) == 2
            layer_2 = module.proj_out.linears[1]
            if isinstance(layer_2, ShiftedLinear):
                down_proj, down_proj_rname = layer_2.linear, "proj_out.linears.1.linear"
                act_type = "gelu_shifted"
            else:
                down_proj, down_proj_rname = layer_2, "proj_out.linears.1"
            # stitch a virtual FFN module out of the block's own layers
            ffn = nn.Sequential(up_proj, module.act_mlp, layer_2)
            assert not rname, f"Unsupported rname: {rname}"
        elif isinstance(module, GLUMBConv):
            ffn = module
            up_proj, up_proj_rname = module.conv_inverted, "conv_inverted"
            down_proj, down_proj_rname = module.conv_point, "conv_point"
            act_type = "silu_conv_silu_glu"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        config = FeedForwardConfigStruct(
            hidden_size=up_proj.weight.shape[1],
            intermediate_size=down_proj.weight.shape[1],
            intermediate_act_type=act_type,
            num_experts=1,
        )
        return DiffusionFeedForwardStruct(
            module=ffn,  # this may be a virtual module
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            up_projs=[up_proj],
            down_projs=[down_proj],
            up_proj_rnames=[up_proj_rname],
            down_proj_rnames=[down_proj_rname],
        )


@dataclass(kw_only=True)
class DiffusionTransformerBlockStruct(TransformerBlockStruct, DiffusionBlockStruct):
    # region relative keys
    norm_rkey: tp.ClassVar[str] = "transformer_norm"
    add_norm_rkey: tp.ClassVar[str] = "transformer_add_norm"
    attn_struct_cls: tp.ClassVar[type[DiffusionAttentionStruct]] = DiffusionAttentionStruct
    ffn_struct_cls: tp.ClassVar[type[DiffusionFeedForwardStruct]] = DiffusionFeedForwardStruct
    # endregion

    parent: tp.Optional["DiffusionTransformerStruct"] = field(repr=False)
    # region child modules
    post_attn_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
    post_ffn_norm: None = field(init=False, repr=False, default=None)
    post_add_ffn_norm: None = field(init=False, repr=False, default=None)
    # endregion
    # region relative names
    post_attn_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    post_attn_add_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
    post_ffn_norm_rname: str = field(init=False, repr=False, default="")
    post_add_ffn_norm_rname: str = field(init=False, repr=False, default="")
    # endregion
    # region attributes
    norm_type: str
    add_norm_type: str
    # endregion
    # region absolute keys
    norm_key: str = field(init=False, repr=False)
    add_norm_key: str = field(init=False, repr=False)
    # endregion
    # region child structs
    pre_attn_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
    pre_attn_add_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
    pre_ffn_norm_struct: DiffusionModuleStruct | None = field(init=False, repr=False, default=None)
    pre_add_ffn_norm_struct: DiffusionModuleStruct | None = field(init=False, repr=False, default=None)
    attn_structs: list[DiffusionAttentionStruct] = field(init=False, repr=False)
    ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
    add_ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
    # endregion

    def __post_init__(self) -> None:
        super().__post_init__()
        self.norm_key = join_name(self.key, self.norm_rkey, sep="_")
        self.add_norm_key = join_name(self.key, self.add_norm_rkey, sep="_")
        self.attn_norm_structs = [
            DiffusionModuleStruct(norm, parent=self, fname="pre_attn_norm", rname=rname, rkey=self.norm_rkey, idx=idx)
            for idx, (norm, rname) in enumerate(zip(self.pre_attn_norms, self.pre_attn_norm_rnames, strict=True))
        ]
        self.add_attn_norm_structs = [
            DiffusionModuleStruct(
                norm, parent=self, fname="pre_attn_add_norm", rname=rname, rkey=self.add_norm_rkey, idx=idx
            )
            for idx, (norm, rname) in enumerate(
                zip(self.pre_attn_add_norms, self.pre_attn_add_norm_rnames, strict=True)
            )
        ]
        if self.pre_ffn_norm is not None:
            self.pre_ffn_norm_struct = DiffusionModuleStruct(
                self.pre_ffn_norm, parent=self, fname="pre_ffn_norm", rname=self.pre_ffn_norm_rname, rkey=self.norm_rkey
            )
        if self.pre_add_ffn_norm is not None:
            self.pre_add_ffn_norm_struct = DiffusionModuleStruct(
                self.pre_add_ffn_norm,
                parent=self,
                fname="pre_add_ffn_norm",
                rname=self.pre_add_ffn_norm_rname,
                rkey=self.add_norm_rkey,
            )

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        for attn_norm in self.attn_norm_structs:
            if attn_norm.module is not None:
                yield from attn_norm.named_key_modules()
        for add_attn_norm in self.add_attn_norm_structs:
            if add_attn_norm.module is not None:
                yield from add_attn_norm.named_key_modules()
        for attn_struct in self.attn_structs:
            yield from attn_struct.named_key_modules()
        if self.pre_ffn_norm_struct is not None:
            # skip when the FFN norm is shared with the first attention norm (already yielded above)
            if self.pre_attn_norms and self.pre_attn_norms[0] is not self.pre_ffn_norm:
                yield from self.pre_ffn_norm_struct.named_key_modules()
        if self.ffn_struct is not None:
            yield from self.ffn_struct.named_key_modules()
        if self.pre_add_ffn_norm_struct is not None:
            # likewise for the added (context) FFN norm
            if self.pre_attn_add_norms and self.pre_attn_add_norms[0] is not self.pre_add_ffn_norm:
                yield from self.pre_add_ffn_norm_struct.named_key_modules()
        if self.add_ffn_struct is not None:
            yield from self.add_ffn_struct.named_key_modules()

    @staticmethod
    def _default_construct(
        module: DIT_BLOCK_CLS,
        /,
        parent: tp.Optional["DiffusionTransformerStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionTransformerBlockStruct":
        if isinstance(module, (BasicTransformerBlock, SanaTransformerBlock)):
            parallel = False
            if isinstance(module, SanaTransformerBlock):
                norm_type = add_norm_type = "ada_norm_single"
            else:
                norm_type = add_norm_type = module.norm_type
            pre_attn_norms, pre_attn_norm_rnames = [], []
            attns, attn_rnames = [], []
            pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
            assert module.norm1 is not None
            assert module.attn1 is not None
            pre_attn_norms.append(module.norm1)
            pre_attn_norm_rnames.append("norm1")
            attns.append(module.attn1)
            attn_rnames.append("attn1")
            pre_attn_add_norms.append(module.attn1.norm_cross)
            pre_attn_add_norm_rnames.append("attn1.norm_cross")
            if module.attn2 is not None:
                if norm_type == "ada_norm_single":
                    pre_attn_norms.append(None)
                    pre_attn_norm_rnames.append("")
                else:
                    assert module.norm2 is not None
                    pre_attn_norms.append(module.norm2)
                    pre_attn_norm_rnames.append("norm2")
                attns.append(module.attn2)
                attn_rnames.append("attn2")
                pre_attn_add_norms.append(module.attn2.norm_cross)
                pre_attn_add_norm_rnames.append("attn2.norm_cross")
            if norm_type == "ada_norm_single":
                assert module.norm2 is not None
                pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            else:
                pre_ffn_norm, pre_ffn_norm_rname = module.norm3, "" if module.norm3 is None else "norm3"
            ffn, ffn_rname = module.ff, "" if module.ff is None else "ff"
            pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
        elif isinstance(module, JointTransformerBlock):
            parallel = False
            norm_type = "ada_norm_zero"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
            if isinstance(module.norm1_context, AdaLayerNormZero):
                add_norm_type = "ada_norm_zero"
            else:
                add_norm_type = "ada_norm_continous"
            pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            ffn, ffn_rname = module.ff, "ff"
            pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
            add_ffn, add_ffn_rname = module.ff_context, "ff_context"
        elif isinstance(module, FluxSingleTransformerBlock):
            parallel = True
            norm_type = add_norm_type = "ada_norm_zero_single"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm], ["norm"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
            pre_ffn_norm, pre_ffn_norm_rname = module.norm, "norm"
            ffn, ffn_rname = module, ""
            pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
        elif isinstance(module, FluxTransformerBlock):
            parallel = False
            norm_type = add_norm_type = "ada_norm_zero"
            pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
            attns, attn_rnames = [module.attn], ["attn"]
            pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
            pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
            ffn, ffn_rname = module.ff, "ff"
            pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
            add_ffn, add_ffn_rname = module.ff_context, "ff_context"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        return DiffusionTransformerBlockStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            parallel=parallel,
            pre_attn_norms=pre_attn_norms,
            pre_attn_add_norms=pre_attn_add_norms,
            attns=attns,
            pre_ffn_norm=pre_ffn_norm,
            ffn=ffn,
            pre_add_ffn_norm=pre_add_ffn_norm,
            add_ffn=add_ffn,
            pre_attn_norm_rnames=pre_attn_norm_rnames,
            pre_attn_add_norm_rnames=pre_attn_add_norm_rnames,
            attn_rnames=attn_rnames,
            pre_ffn_norm_rname=pre_ffn_norm_rname,
            ffn_rname=ffn_rname,
            pre_add_ffn_norm_rname=pre_add_ffn_norm_rname,
            add_ffn_rname=add_ffn_rname,
            norm_type=norm_type,
            add_norm_type=add_norm_type,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default key map from group keys to the module keys they cover."""
        key_map: dict[str, set[str]] = defaultdict(set)
        norm_rkey = norm_key = cls.norm_rkey
        add_norm_rkey = add_norm_key = cls.add_norm_rkey
        key_map[norm_rkey].add(norm_key)
        key_map[add_norm_rkey].add(add_norm_key)
        attn_cls = cls.attn_struct_cls
        attn_key = attn_rkey = cls.attn_rkey
        qkv_proj_key = qkv_proj_rkey = join_name(attn_key, attn_cls.qkv_proj_rkey, sep="_")
        out_proj_key = out_proj_rkey = join_name(attn_key, attn_cls.out_proj_rkey, sep="_")
        add_qkv_proj_key = add_qkv_proj_rkey = join_name(attn_key, attn_cls.add_qkv_proj_rkey, sep="_")
        add_out_proj_key = add_out_proj_rkey = join_name(attn_key, attn_cls.add_out_proj_rkey, sep="_")
        key_map[attn_rkey].add(qkv_proj_key)
        key_map[attn_rkey].add(out_proj_key)
        if attn_cls.add_qkv_proj_rkey.startswith("add_") and attn_cls.add_out_proj_rkey.startswith("add_"):
            add_attn_rkey = join_name(attn_rkey, "add", sep="_")
            key_map[add_attn_rkey].add(add_qkv_proj_key)
            key_map[add_attn_rkey].add(add_out_proj_key)
        key_map[qkv_proj_rkey].add(qkv_proj_key)
        key_map[out_proj_rkey].add(out_proj_key)
        key_map[add_qkv_proj_rkey].add(add_qkv_proj_key)
        key_map[add_out_proj_rkey].add(add_out_proj_key)
        ffn_cls = cls.ffn_struct_cls
        ffn_key = ffn_rkey = cls.ffn_rkey
        add_ffn_key = add_ffn_rkey = cls.add_ffn_rkey
        up_proj_key = up_proj_rkey = join_name(ffn_key, ffn_cls.up_proj_rkey, sep="_")
        down_proj_key = down_proj_rkey = join_name(ffn_key, ffn_cls.down_proj_rkey, sep="_")
        add_up_proj_key = add_up_proj_rkey = join_name(add_ffn_key, ffn_cls.up_proj_rkey, sep="_")
        add_down_proj_key = add_down_proj_rkey = join_name(add_ffn_key, ffn_cls.down_proj_rkey, sep="_")
        key_map[ffn_rkey].add(up_proj_key)
        key_map[ffn_rkey].add(down_proj_key)
        key_map[add_ffn_rkey].add(add_up_proj_key)
        key_map[add_ffn_rkey].add(add_down_proj_key)
        key_map[up_proj_rkey].add(up_proj_key)
        key_map[down_proj_rkey].add(down_proj_key)
        key_map[add_up_proj_rkey].add(add_up_proj_key)
        key_map[add_down_proj_rkey].add(add_down_proj_key)
        return {k: v for k, v in key_map.items() if v}
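    # Illustration with hypothetical rkey values (the real ones are class vars on the
    # base structs): if attn_rkey == "attn" and attn_cls.qkv_proj_rkey == "qkv_proj",
    # the map gains {"attn": {"attn_qkv_proj", "attn_out_proj"},
    #                "attn_qkv_proj": {"attn_qkv_proj"}, ...}, i.e. a coarse group key
    # per module plus a fine-grained key for each projection.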


@dataclass(kw_only=True)
class DiffusionTransformerStruct(BaseTransformerStruct, DiffusionBlockStruct):
    # region relative keys
    proj_in_rkey: tp.ClassVar[str] = "transformer_proj_in"
    proj_out_rkey: tp.ClassVar[str] = "transformer_proj_out"
    transformer_block_rkey: tp.ClassVar[str] = ""
    transformer_block_struct_cls: tp.ClassVar[type[DiffusionTransformerBlockStruct]] = DiffusionTransformerBlockStruct
    # endregion

    module: Transformer2DModel = field(repr=False, kw_only=False)
    # region modules
    norm_in: nn.GroupNorm | None
    """Input normalization"""
    proj_in: nn.Linear | nn.Conv2d
    """Input projection"""
    norm_out: nn.GroupNorm | None
    """Output normalization"""
    proj_out: nn.Linear | nn.Conv2d
    """Output projection"""
    transformer_blocks: nn.ModuleList = field(repr=False)
    """Transformer blocks"""
    # endregion
    # region relative names
    transformer_blocks_rname: str
    # endregion
    # region absolute names
    transformer_blocks_name: str = field(init=False, repr=False)
    transformer_block_names: list[str] = field(init=False, repr=False)
    # endregion
    # region child structs
    transformer_block_structs: list[DiffusionTransformerBlockStruct] = field(init=False, repr=False)
    # endregion

    # region aliases

    @property
    def num_blocks(self) -> int:
        return len(self.transformer_blocks)

    @property
    def block_structs(self) -> list[DiffusionBlockStruct]:
        return self.transformer_block_structs

    @property
    def block_names(self) -> list[str]:
        return self.transformer_block_names

    # endregion

    def __post_init__(self):
        super().__post_init__()
        transformer_block_rnames = [
            f"{self.transformer_blocks_rname}.{idx}" for idx in range(len(self.transformer_blocks))
        ]
        self.transformer_blocks_name = join_name(self.name, self.transformer_blocks_rname)
        self.transformer_block_names = [join_name(self.name, rname) for rname in transformer_block_rnames]
        self.transformer_block_structs = [
            self.transformer_block_struct_cls.construct(
                layer,
                parent=self,
                fname="transformer_block",
                rname=rname,
                rkey=self.transformer_block_rkey,
                idx=idx,
            )
            for idx, (layer, rname) in enumerate(zip(self.transformer_blocks, transformer_block_rnames, strict=True))
        ]

    @staticmethod
    def _default_construct(
        module: Transformer2DModel,
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionTransformerStruct":
        if isinstance(module, Transformer2DModel):
            assert module.is_input_continuous, "input must be continuous"
            transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
            norm_in, norm_in_rname = module.norm, "norm"
            proj_in, proj_in_rname = module.proj_in, "proj_in"
            proj_out, proj_out_rname = module.proj_out, "proj_out"
            norm_out, norm_out_rname = None, ""
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        return DiffusionTransformerStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            norm_in=norm_in,
            proj_in=proj_in,
            transformer_blocks=transformer_blocks,
            proj_out=proj_out,
            norm_out=norm_out,
            norm_in_rname=norm_in_rname,
            proj_in_rname=proj_in_rname,
            transformer_blocks_rname=transformer_blocks_rname,
            norm_out_rname=norm_out_rname,
            proj_out_rname=proj_out_rname,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default key map from group keys to the module keys they cover."""
        key_map: dict[str, set[str]] = defaultdict(set)
        proj_in_rkey = proj_in_key = cls.proj_in_rkey
        proj_out_rkey = proj_out_key = cls.proj_out_rkey
        key_map[proj_in_rkey].add(proj_in_key)
        key_map[proj_out_rkey].add(proj_out_key)
        block_cls = cls.transformer_block_struct_cls
        block_key = block_rkey = cls.transformer_block_rkey
        block_key_map = block_cls._get_default_key_map()
        for rkey, keys in block_key_map.items():
            rkey = join_name(block_rkey, rkey, sep="_")
            for key in keys:
                key = join_name(block_key, key, sep="_")
                key_map[rkey].add(key)
        return {k: v for k, v in key_map.items() if v}


@dataclass(kw_only=True)
class DiffusionResnetStruct(BaseModuleStruct):
    # region relative keys
    conv_rkey: tp.ClassVar[str] = "conv"
    shortcut_rkey: tp.ClassVar[str] = "shortcut"
    time_proj_rkey: tp.ClassVar[str] = "time_proj"
    # endregion

    module: ResnetBlock2D = field(repr=False, kw_only=False)
    """the ResnetBlock2D module"""
    config: FeedForwardConfigStruct
    # region child modules
    norms: list[nn.GroupNorm]
    convs: list[list[nn.Conv2d]]
    shortcut: nn.Conv2d | None
    time_proj: nn.Linear | None
    # endregion
    # region relative names
    norm_rnames: list[str]
    conv_rnames: list[list[str]]
    shortcut_rname: str
    time_proj_rname: str
    # endregion
    # region absolute names
    norm_names: list[str] = field(init=False, repr=False)
    conv_names: list[list[str]] = field(init=False, repr=False)
    shortcut_name: str = field(init=False, repr=False)
    time_proj_name: str = field(init=False, repr=False)
    # endregion
    # region absolute keys
    conv_key: str = field(init=False, repr=False)
    shortcut_key: str = field(init=False, repr=False)
    time_proj_key: str = field(init=False, repr=False)
    # endregion

    def __post_init__(self):
        super().__post_init__()
        self.norm_names = [join_name(self.name, rname) for rname in self.norm_rnames]
        self.conv_names = [[join_name(self.name, rname) for rname in rnames] for rnames in self.conv_rnames]
        self.shortcut_name = join_name(self.name, self.shortcut_rname)
        self.time_proj_name = join_name(self.name, self.time_proj_rname)
        self.conv_key = join_name(self.key, self.conv_rkey, sep="_")
        self.shortcut_key = join_name(self.key, self.shortcut_rkey, sep="_")
        self.time_proj_key = join_name(self.key, self.time_proj_rkey, sep="_")

    def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        for convs, names in zip(self.convs, self.conv_names, strict=True):
            for conv, name in zip(convs, names, strict=True):
                yield self.conv_key, name, conv, self, "conv"
        if self.shortcut is not None:
            yield self.shortcut_key, self.shortcut_name, self.shortcut, self, "shortcut"
        if self.time_proj is not None:
            yield self.time_proj_key, self.time_proj_name, self.time_proj, self, "time_proj"

    @staticmethod
    def construct(
        module: ResnetBlock2D,
        /,
        parent: tp.Optional[BaseModuleStruct] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "DiffusionResnetStruct":
        if isinstance(module, ResnetBlock2D):
            assert module.upsample is None, "upsample must be None"
            assert module.downsample is None, "downsample must be None"
            act_type = module.nonlinearity.__class__.__name__.lower()
            shifted = False
            if isinstance(module.conv1, ConcatConv2d):
                conv1_convs, conv1_names = [], []
                for conv_idx, conv in enumerate(module.conv1.convs):
                    if isinstance(conv, ShiftedConv2d):
                        shifted = True
                        conv1_convs.append(conv.conv)
                        conv1_names.append(f"conv1.convs.{conv_idx}.conv")
                    else:
                        assert isinstance(conv, nn.Conv2d)
                        conv1_convs.append(conv)
                        conv1_names.append(f"conv1.convs.{conv_idx}")
            elif isinstance(module.conv1, ShiftedConv2d):
                shifted = True
                conv1_convs = [module.conv1.conv]
                conv1_names = ["conv1.conv"]
            else:
                assert isinstance(module.conv1, nn.Conv2d)
                conv1_convs, conv1_names = [module.conv1], ["conv1"]
            if isinstance(module.conv2, ConcatConv2d):
                conv2_convs, conv2_names = [], []
                for conv_idx, conv in enumerate(module.conv2.convs):
                    if isinstance(conv, ShiftedConv2d):
                        shifted = True
                        conv2_convs.append(conv.conv)
                        conv2_names.append(f"conv2.convs.{conv_idx}.conv")
                    else:
                        assert isinstance(conv, nn.Conv2d)
                        conv2_convs.append(conv)
                        conv2_names.append(f"conv2.convs.{conv_idx}")
            elif isinstance(module.conv2, ShiftedConv2d):
                shifted = True
                conv2_convs = [module.conv2.conv]
                conv2_names = ["conv2.conv"]
            else:
                assert isinstance(module.conv2, nn.Conv2d)
                conv2_convs, conv2_names = [module.conv2], ["conv2"]
            convs, conv_rnames = [conv1_convs, conv2_convs], [conv1_names, conv2_names]
            norms, norm_rnames = [module.norm1, module.norm2], ["norm1", "norm2"]
            shortcut, shortcut_rname = module.conv_shortcut, "" if module.conv_shortcut is None else "conv_shortcut"
            time_proj, time_proj_rname = module.time_emb_proj, "" if module.time_emb_proj is None else "time_emb_proj"
            if shifted:
                assert all(hasattr(conv, "shifted") and conv.shifted for level_convs in convs for conv in level_convs)
                act_type += "_shifted"
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        config = FeedForwardConfigStruct(
            hidden_size=convs[0][0].weight.shape[1],
            intermediate_size=convs[0][0].weight.shape[0],
            intermediate_act_type=act_type,
            num_experts=1,
        )
        return DiffusionResnetStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            config=config,
            norms=norms,
            convs=convs,
            shortcut=shortcut,
            time_proj=time_proj,
            norm_rnames=norm_rnames,
            conv_rnames=conv_rnames,
            shortcut_rname=shortcut_rname,
            time_proj_rname=time_proj_rname,
        )

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default key map from group keys to the module keys they cover."""
        key_map: dict[str, set[str]] = defaultdict(set)
        conv_key = conv_rkey = cls.conv_rkey
        shortcut_key = shortcut_rkey = cls.shortcut_rkey
        time_proj_key = time_proj_rkey = cls.time_proj_rkey
        key_map[conv_rkey].add(conv_key)
        key_map[shortcut_rkey].add(shortcut_key)
        key_map[time_proj_rkey].add(time_proj_key)
        return {k: v for k, v in key_map.items() if v}
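
The `_get_default_key_map` pattern above flattens a child's key map by prefixing its relative keys. A minimal standalone sketch of that flattening, assuming a simplified `join_name` stand-in (the repo's real helper may behave differently; names here are illustrative):

```python
from collections import defaultdict


def join_name(prefix: str, name: str, sep: str = ".") -> str:
    # Simplified stand-in for the repo's join_name helper:
    # join non-empty components with the separator.
    return sep.join(p for p in (prefix, name) if p)


def flatten_key_map(child_rkey: str, child_map: dict[str, set[str]]) -> dict[str, set[str]]:
    # Mirror the pattern above: each child relative key is prefixed with the
    # child's rkey, and every absolute key is also registered under the bare
    # child rkey so coarse-grained lookups still match.
    key_map: dict[str, set[str]] = defaultdict(set)
    for rkey, keys in child_map.items():
        joined_rkey = join_name(child_rkey, rkey, sep="_")
        for key in keys:
            joined_key = join_name(child_rkey, key, sep="_")
            key_map[joined_rkey].add(joined_key)
            key_map[child_rkey].add(joined_key)
    return {k: v for k, v in key_map.items() if v}


print(flatten_key_map("resblock", {"conv": {"conv"}}))
# -> {'resblock_conv': {'resblock_conv'}, 'resblock': {'resblock_conv'}}
```

Note that the trailing filter drops empty *sets*, not empty *keys*, which is why sibling loops that deal with an empty child rkey are written differently.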


@dataclass(kw_only=True)
class UNetBlockStruct(DiffusionBlockStruct):
    class BlockType(enum.StrEnum):
        DOWN = "down"
        MID = "mid"
        UP = "up"

    # region relative keys
    resnet_rkey: tp.ClassVar[str] = "resblock"
    sampler_rkey: tp.ClassVar[str] = "sample"
    transformer_rkey: tp.ClassVar[str] = ""
    resnet_struct_cls: tp.ClassVar[type[DiffusionResnetStruct]] = DiffusionResnetStruct
    transformer_struct_cls: tp.ClassVar[type[DiffusionTransformerStruct]] = DiffusionTransformerStruct
    # endregion

    parent: tp.Optional["UNetStruct"] = field(repr=False)
    # region attributes
    block_type: BlockType
    # endregion
    # region modules
    resnets: nn.ModuleList = field(repr=False)
    transformers: nn.ModuleList = field(repr=False)
    sampler: nn.Conv2d | None
    # endregion
    # region relative names
    resnets_rname: str
    transformers_rname: str
    sampler_rname: str
    # endregion
    # region absolute names
    resnets_name: str = field(init=False, repr=False)
    transformers_name: str = field(init=False, repr=False)
    sampler_name: str = field(init=False, repr=False)
    resnet_names: list[str] = field(init=False, repr=False)
    transformer_names: list[str] = field(init=False, repr=False)
    # endregion
    # region absolute keys
    sampler_key: str = field(init=False, repr=False)
    # endregion
    # region child structs
    resnet_structs: list[DiffusionResnetStruct] = field(init=False, repr=False)
    transformer_structs: list[DiffusionTransformerStruct] = field(init=False, repr=False)
    # endregion

    @property
    def downsample(self) -> nn.Conv2d | None:
        return self.sampler if self.is_downsample_block() else None

    @property
    def upsample(self) -> nn.Conv2d | None:
        return self.sampler if self.is_upsample_block() else None

    def __post_init__(self) -> None:
        super().__post_init__()
        if self.is_downsample_block():
            assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
            if self.parent is not None and isinstance(self.parent, UNetStruct):
                assert self.rname == f"{self.parent.down_blocks_rname}.{self.idx}"
        elif self.is_mid_block():
            assert len(self.resnets) == len(self.transformers) + 1 or len(self.transformers) == 0
            if self.parent is not None and isinstance(self.parent, UNetStruct):
                assert self.rname == self.parent.mid_block_rname
                assert self.idx == 0
        else:
            assert self.is_upsample_block(), f"Unsupported block type: {self.block_type}"
            assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
            if self.parent is not None and isinstance(self.parent, UNetStruct):
                assert self.rname == f"{self.parent.up_blocks_rname}.{self.idx}"
        resnet_rnames = [f"{self.resnets_rname}.{idx}" for idx in range(len(self.resnets))]
        transformer_rnames = [f"{self.transformers_rname}.{idx}" for idx in range(len(self.transformers))]
        self.resnets_name = join_name(self.name, self.resnets_rname)
        self.transformers_name = join_name(self.name, self.transformers_rname)
        self.resnet_names = [join_name(self.name, rname) for rname in resnet_rnames]
        self.transformer_names = [join_name(self.name, rname) for rname in transformer_rnames]
        self.sampler_name = join_name(self.name, self.sampler_rname)
        self.sampler_key = join_name(self.key, self.sampler_rkey, sep="_")
        self.resnet_structs = [
            self.resnet_struct_cls.construct(
                resnet, parent=self, fname="resnet", rname=rname, rkey=self.resnet_rkey, idx=idx
            )
            for idx, (resnet, rname) in enumerate(zip(self.resnets, resnet_rnames, strict=True))
        ]
        self.transformer_structs = [
            self.transformer_struct_cls.construct(
                transformer, parent=self, fname="transformer", rname=rname, rkey=self.transformer_rkey, idx=idx
            )
            for idx, (transformer, rname) in enumerate(zip(self.transformers, transformer_rnames, strict=True))
        ]

    def is_downsample_block(self) -> bool:
        return self.block_type == self.BlockType.DOWN

    def is_mid_block(self) -> bool:
        return self.block_type == self.BlockType.MID

    def is_upsample_block(self) -> bool:
        return self.block_type == self.BlockType.UP

    def has_downsample(self) -> bool:
        return self.is_downsample_block() and self.sampler is not None

    def has_upsample(self) -> bool:
        return self.is_upsample_block() and self.sampler is not None

    def named_key_modules(self) -> tp.Generator[tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
        for resnet in self.resnet_structs:
            yield from resnet.named_key_modules()
        for transformer in self.transformer_structs:
            yield from transformer.named_key_modules()
        if self.sampler is not None:
            yield self.sampler_key, self.sampler_name, self.sampler, self, "sampler"
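
`named_key_modules` composes recursively: a block delegates to its children with `yield from`, then emits its own sampler entry last. A toy sketch of that traversal (the classes below are hypothetical stand-ins, not the repo's structs, and the entries are reduced to `(key, name)` pairs):

```python
import typing as tp


class LeafStruct:
    # Hypothetical stand-in for a leaf struct (e.g. one resnet conv entry).
    def __init__(self, key: str, name: str) -> None:
        self.key, self.name = key, name

    def named_key_modules(self) -> tp.Generator[tuple[str, str], None, None]:
        yield self.key, self.name


class BlockStruct:
    # Mirrors the pattern above: delegate to children with `yield from`,
    # then emit the block's own sampler entry if one exists.
    def __init__(self, children: list[LeafStruct], sampler_key: str = "") -> None:
        self.children, self.sampler_key = children, sampler_key

    def named_key_modules(self) -> tp.Generator[tuple[str, str], None, None]:
        for child in self.children:
            yield from child.named_key_modules()
        if self.sampler_key:
            yield self.sampler_key, "downsamplers.0.conv"


block = BlockStruct([LeafStruct("resblock_conv", "resnets.0.conv1")], sampler_key="sample")
print(list(block.named_key_modules()))
# -> [('resblock_conv', 'resnets.0.conv1'), ('sample', 'downsamplers.0.conv')]
```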

    def iter_attention_structs(self) -> tp.Generator[DiffusionAttentionStruct, None, None]:
        for transformer in self.transformer_structs:
            yield from transformer.iter_attention_structs()

    def iter_transformer_block_structs(self) -> tp.Generator[DiffusionTransformerBlockStruct, None, None]:
        for transformer in self.transformer_structs:
            yield from transformer.iter_transformer_block_structs()

    @staticmethod
    def _default_construct(
        module: UNET_BLOCK_CLS,
        /,
        parent: tp.Optional["UNetStruct"] = None,
        fname: str = "",
        rname: str = "",
        rkey: str = "",
        idx: int = 0,
        **kwargs,
    ) -> "UNetBlockStruct":
        resnets, resnets_rname = module.resnets, "resnets"
        if isinstance(module, (DownBlock2D, CrossAttnDownBlock2D)):
            block_type = UNetBlockStruct.BlockType.DOWN
            if isinstance(module, CrossAttnDownBlock2D) and module.attentions is not None:
                transformers, transformers_rname = module.attentions, "attentions"
            else:
                transformers, transformers_rname = [], ""
            if module.downsamplers is None:
                sampler, sampler_rname = None, ""
            else:
                assert len(module.downsamplers) == 1
                downsampler = module.downsamplers[0]
                assert isinstance(downsampler, Downsample2D)
                sampler, sampler_rname = downsampler.conv, "downsamplers.0.conv"
                assert isinstance(sampler, nn.Conv2d)
        elif isinstance(module, (UNetMidBlock2D, UNetMidBlock2DCrossAttn)):
            block_type = UNetBlockStruct.BlockType.MID
            if (isinstance(module, UNetMidBlock2DCrossAttn) or module.add_attention) and module.attentions is not None:
                transformers, transformers_rname = module.attentions, "attentions"
            else:
                transformers, transformers_rname = [], ""
            sampler, sampler_rname = None, ""
        elif isinstance(module, (UpBlock2D, CrossAttnUpBlock2D)):
            block_type = UNetBlockStruct.BlockType.UP
            if isinstance(module, CrossAttnUpBlock2D) and module.attentions is not None:
                transformers, transformers_rname = module.attentions, "attentions"
            else:
                transformers, transformers_rname = [], ""
            if module.upsamplers is None:
                sampler, sampler_rname = None, ""
            else:
                assert len(module.upsamplers) == 1
                upsampler = module.upsamplers[0]
                assert isinstance(upsampler, Upsample2D)
                sampler, sampler_rname = upsampler.conv, "upsamplers.0.conv"
                assert isinstance(sampler, nn.Conv2d)
        else:
            raise NotImplementedError(f"Unsupported module type: {type(module)}")
        return UNetBlockStruct(
            module=module,
            parent=parent,
            fname=fname,
            idx=idx,
            rname=rname,
            rkey=rkey,
            block_type=block_type,
            resnets=resnets,
            transformers=transformers,
            sampler=sampler,
            resnets_rname=resnets_rname,
            transformers_rname=transformers_rname,
            sampler_rname=sampler_rname,
        )
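
`_default_construct` above dispatches on the concrete diffusers block class to decide the block type and whether an optional resampling conv exists. The shape of that dispatch, reduced to a standalone sketch (the block classes here are hypothetical placeholders for the diffusers types):

```python
class DownBlock: ...
class MidBlock: ...
class UpBlock: ...


def classify(module: object) -> tuple[str, str]:
    # Mirror the dispatch above: derive (block_type, sampler_rname) from the
    # concrete class, and fail loudly on anything unsupported.
    if isinstance(module, DownBlock):
        return "down", "downsamplers.0.conv"
    elif isinstance(module, MidBlock):
        return "mid", ""  # mid blocks never carry a resampling conv
    elif isinstance(module, UpBlock):
        return "up", "upsamplers.0.conv"
    raise NotImplementedError(f"Unsupported module type: {type(module)}")


print(classify(MidBlock()))  # -> ('mid', '')
```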

    @classmethod
    def _get_default_key_map(cls) -> dict[str, set[str]]:
        """Get the default allowed keys."""
        key_map: dict[str, set[str]] = defaultdict(set)
        resnet_cls = cls.resnet_struct_cls
        resnet_key = resnet_rkey = cls.resnet_rkey
        resnet_key_map = resnet_cls._get_default_key_map()
        for rkey, keys in resnet_key_map.items():
            rkey = join_name(resnet_rkey, rkey, sep="_")
            for key in keys:
                key = join_name(resnet_key, key, sep="_")
                key_map[rkey].add(key)
                key_map[resnet_rkey].add(key)
        transformer_cls = cls.transformer_struct_cls
        transformer_key = transformer_rkey = cls.transformer_rkey
        transformer_key_map = transformer_cls._get_default_key_map()
        for rkey, keys in transformer_key_map.items():
            trkey = join_name(transformer_rkey, rkey, sep="_")
            for key in keys:
                key = join_name(transformer_key, key, sep="_")
                key_map[rkey].add(key)
                key_map[trkey].add(key)
        return {k: v for k, v in key_map.items() if v}


@dataclass(kw_only=True)
class UNetStruct(DiffusionModelStruct):
    # region relative keys
    input_embed_rkey: tp.ClassVar[str] = "input_embed"
    """hidden_states = input_embed(hidden_states), e.g., conv_in"""
    time_embed_rkey: tp.ClassVar[str] = "time_embed"
    """temb = time_embed(timesteps, hidden_states)"""
    add_time_embed_rkey: tp.ClassVar[str] = "time_embed"
    """add_temb = add_time_embed(timesteps, encoder_hidden_states)"""
    text_embed_rkey: tp.ClassVar[str] = "text_embed"
    """encoder_hidden_states = text_embed(encoder_hidden_states)"""
    norm_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = norm_out(hidden_states), e.g., conv_norm_out"""
    proj_out_rkey: tp.ClassVar[str] = "output_embed"
    """hidden_states = output_embed(hidden_states), e.g., conv_out"""
    down_block_rkey: tp.ClassVar[str] = "down"
    mid_block_rkey: tp.ClassVar[str] = "mid"
    up_block_rkey: tp.ClassVar[str] = "up"
    down_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    mid_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    up_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
    # endregion

    # region child modules
    # region pre-block modules
    input_embed: nn.Conv2d
    time_embed: TimestepEmbedding
    """Time embedding"""
    add_time_embed: (
        TextTimeEmbedding
        | TextImageTimeEmbedding
        | TimestepEmbedding
        | ImageTimeEmbedding
        | ImageHintTimeEmbedding
        | None
    )
    """Additional time embedding"""
    text_embed: nn.Linear | ImageProjection | TextImageProjection | None
    """Text embedding"""
    # region post-block modules
    norm_out: nn.GroupNorm | None
    proj_out: nn.Conv2d
    # endregion
    # endregion
    down_blocks: nn.ModuleList = field(repr=False)
    mid_block: nn.Module = field(repr=False)
    up_blocks: nn.ModuleList = field(repr=False)
    # endregion
    # region relative names
    input_embed_rname: str
    time_embed_rname: str
    add_time_embed_rname: str
    text_embed_rname: str
    norm_out_rname: str
    proj_out_rname: str
    down_blocks_rname: str
    mid_block_rname: str
    up_blocks_rname: str
    # endregion
    # region absolute names
    input_embed_name: str = field(init=False, repr=False)
    time_embed_name: str = field(init=False, repr=False)
    add_time_embed_name: str = field(init=False, repr=False)
    text_embed_name: str = field(init=False, repr=False)
    norm_out_name: str = field(init=False, repr=False)
    proj_out_name: str = field(init=False, repr=False)
    down_blocks_name: str = field(init=False, repr=False)
    mid_block_name: str = field(init=False, repr=False)
    up_blocks_name: str = field(init=False, repr=False)
    down_block_names: list[str] = field(init=False, repr=False)
    up_block_names: list[str] = field(init=False, repr=False)
    # endregion
    # region absolute keys
    input_embed_key: str = field(init=False, repr=False)
    time_embed_key: str = field(init=False, repr=False)
    add_time_embed_key: str = field(init=False, repr=False)
    text_embed_key: str = field(init=False, repr=False)
    norm_out_key: str = field(init=False, repr=False)
    proj_out_key: str = field(init=False, repr=False)
    # endregion
    # region child structs
    down_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
    mid_block_struct: UNetBlockStruct = field(init=False, repr=False)
    up_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
    # endregion

    @property
    def num_down_blocks(self) -> int:
        return len(self.down_blocks)

    @property
    def num_up_blocks(self) -> int:
        return len(self.up_blocks)

    @property
    def num_blocks(self) -> int:
        return self.num_down_blocks + 1 + self.num_up_blocks

    @property
    def block_structs(self) -> list[UNetBlockStruct]:
        return [*self.down_block_structs, self.mid_block_struct, *self.up_block_structs]

    def __post_init__(self) -> None:
        super().__post_init__()
        down_block_rnames = [f"{self.down_blocks_rname}.{idx}" for idx in range(len(self.down_blocks))]
        up_block_rnames = [f"{self.up_blocks_rname}.{idx}" for idx in range(len(self.up_blocks))]
        self.down_blocks_name = join_name(self.name, self.down_blocks_rname)
        self.mid_block_name = join_name(self.name, self.mid_block_rname)
        self.up_blocks_name = join_name(self.name, self.up_blocks_rname)
        self.down_block_names = [join_name(self.name, rname) for rname in down_block_rnames]
        self.up_block_names = [join_name(self.name, rname) for rname in up_block_rnames]
        self.pre_module_structs = {}
        for fname in ("time_embed", "add_time_embed", "text_embed", "input_embed"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
            if module is not None or rname:
                setattr(self, f"{fname}_name", join_name(self.name, rname))
            else:
                setattr(self, f"{fname}_name", "")
            if module is not None:
                assert rname, f"rname of {fname} must not be empty"
                self.pre_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
                    module=module, parent=self, fname=fname, rname=rname, rkey=rkey
                )
        self.post_module_structs = {}
        for fname in ("norm_out", "proj_out"):
            module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
            setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
            if module is not None or rname:
                setattr(self, f"{fname}_name", join_name(self.name, rname))
            else:
                setattr(self, f"{fname}_name", "")
            if module is not None:
                self.post_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
                    module=module, parent=self, fname=fname, rname=rname, rkey=rkey
                )
        self.down_block_structs = [
            self.down_block_struct_cls.con
SYMBOL INDEX (1198 symbols across 124 files)

FILE: deepcompressor/app/diffusion/cache/config.py
  class DiffusionQuantCacheConfig (line 19) | class DiffusionQuantCacheConfig(BasePathConfig):
    method simplify_path (line 39) | def simplify_path(path: str, key_map: dict[str, set[str]]) -> str:
    method simplify (line 63) | def simplify(self, key_map: dict[str, set[str]]) -> tp.Self:
  class DiffusionPtqCacheConfig (line 70) | class DiffusionPtqCacheConfig:

FILE: deepcompressor/app/diffusion/config.py
  class DiffusionPtqRunConfig (line 34) | class DiffusionPtqRunConfig:
    method __post_init__ (line 76) | def __post_init__(self):
    method generate_default_dirname (line 158) | def generate_default_dirname(self) -> str:
    method get_parser (line 186) | def get_parser(cls) -> ConfigParser:

FILE: deepcompressor/app/diffusion/dataset/base.py
  class DiffusionDataset (line 18) | class DiffusionDataset(torch.utils.data.Dataset):
    method __init__ (line 23) | def __init__(self, path: str, num_samples: int = -1, seed: int = 0, ex...
    method __len__ (line 38) | def __len__(self) -> int:
    method __getitem__ (line 41) | def __getitem__(self, idx) -> dict[str, tp.Any]:
    method build_loader (line 65) | def build_loader(self, **kwargs) -> torch.utils.data.DataLoader:

FILE: deepcompressor/app/diffusion/dataset/calib.py
  class DiffusionCalibCacheLoaderConfig (line 44) | class DiffusionCalibCacheLoaderConfig(BaseDataLoaderConfig):
    method build_dataset (line 63) | def build_dataset(self) -> "DiffusionCalibDataset":
    method build_loader (line 67) | def build_loader(self) -> "DiffusionCalibCacheLoader":
  class DiffusionCalibDataset (line 72) | class DiffusionCalibDataset(DiffusionDataset):
    method __init__ (line 75) | def __init__(self, path: str, num_samples: int = -1, seed: int = 0) ->...
    method __len__ (line 81) | def __len__(self) -> int:
    method __getitem__ (line 84) | def __getitem__(self, idx) -> dict[str, tp.Any]:
  class DiffusionConcatCacheAction (line 88) | class DiffusionConcatCacheAction(ConcatCacheAction):
    method info (line 89) | def info(
  class DiffusionCalibCacheLoader (line 137) | class DiffusionCalibCacheLoader(BaseCalibCacheLoader):
    method __init__ (line 141) | def __init__(self, config: DiffusionCalibCacheLoaderConfig) -> None:
    method _init_cache (line 152) | def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
    method iter_samples (line 188) | def iter_samples(self) -> tp.Generator[ModuleForwardInput, None, None]:
    method _convert_layer_inputs (line 195) | def _convert_layer_inputs(
    method _convert_layer_outputs (line 239) | def _convert_layer_outputs(self, m: nn.Module, outputs: tp.Any) -> dic...
    method iter_layer_activations (line 259) | def iter_layer_activations(  # noqa: C901

FILE: deepcompressor/app/diffusion/dataset/collect/calib.py
  function process (line 21) | def process(x: torch.Tensor) -> torch.Tensor:
  function collect (line 26) | def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
  class CollectConfig (line 93) | class CollectConfig:

FILE: deepcompressor/app/diffusion/dataset/collect/utils.py
  class CollectHook (line 21) | class CollectHook:
    method __init__ (line 22) | def __init__(self, caches: list[dict[str, tp.Any]] = None, zero_redund...
    method __call__ (line 26) | def __call__(

FILE: deepcompressor/app/diffusion/dataset/data/COCO/COCO.py
  function hash_string_to_int (line 89) | def hash_string_to_int(s: str) -> int:
  class COCOConfig (line 97) | class COCOConfig(datasets.BuilderConfig):
    method __init__ (line 98) | def __init__(self, max_dataset_size: int = -1, return_gt: bool = False...
  class COCO (line 110) | class COCO(datasets.GeneratorBasedBuilder):
    method _info (line 123) | def _info(self):
    method _split_generators (line 132) | def _split_generators(self, dl_manager: datasets.download.DownloadMana...
    method _generate_examples (line 152) | def _generate_examples(

FILE: deepcompressor/app/diffusion/dataset/data/DCI/DCI.py
  class DCIConfig (line 34) | class DCIConfig(datasets.BuilderConfig):
    method __init__ (line 35) | def __init__(self, max_dataset_size: int = -1, return_gt: bool = False...
  class DCI (line 47) | class DCI(datasets.GeneratorBasedBuilder):
    method _info (line 54) | def _info(self):
    method _split_generators (line 70) | def _split_generators(self, dl_manager: datasets.download.DownloadMana...
    method _generate_examples (line 83) | def _generate_examples(self, meta_path: str, image_root: str):

FILE: deepcompressor/app/diffusion/dataset/data/MJHQ/MJHQ.py
  class MJHQConfig (line 36) | class MJHQConfig(datasets.BuilderConfig):
    method __init__ (line 37) | def __init__(self, max_dataset_size: int = -1, return_gt: bool = False...
  class DCI (line 49) | class DCI(datasets.GeneratorBasedBuilder):
    method _info (line 56) | def _info(self):
    method _split_generators (line 73) | def _split_generators(self, dl_manager: datasets.download.DownloadMana...
    method _generate_examples (line 82) | def _generate_examples(self, meta_path: str, image_root: str):

FILE: deepcompressor/app/diffusion/dataset/data/__init__.py
  function load_dataset_yaml (line 10) | def load_dataset_yaml(meta_path: str, max_dataset_size: int = -1, repeat...
  function get_dataset (line 30) | def get_dataset(

FILE: deepcompressor/app/diffusion/eval/config.py
  class DiffusionEvalConfig (line 29) | class DiffusionEvalConfig:
    method __post_init__ (line 103) | def __post_init__(self):
    method get_pipeline_kwargs (line 110) | def get_pipeline_kwargs(self) -> dict[str, tp.Any]:
    method _generate (line 124) | def _generate(
    method generate (line 190) | def generate(
    method evaluate (line 232) | def evaluate(

FILE: deepcompressor/app/diffusion/eval/metrics/__init__.py
  function compute_image_metrics (line 16) | def compute_image_metrics(
  function categorize_metrics (line 80) | def categorize_metrics(metrics: tuple[str, ...]) -> tuple[list[str], lis...

FILE: deepcompressor/app/diffusion/eval/metrics/fid.py
  function get_dataset_features (line 15) | def get_dataset_features(
  function get_fid_features (line 44) | def get_fid_features(
  function compute_fid (line 87) | def compute_fid(

FILE: deepcompressor/app/diffusion/eval/metrics/image_reward.py
  function compute_image_reward (line 10) | def compute_image_reward(

FILE: deepcompressor/app/diffusion/eval/metrics/multimodal.py
  class PromptImageDataset (line 16) | class PromptImageDataset(data.Dataset):
    method __init__ (line 17) | def __init__(self, ref_dataset: datasets.Dataset, gen_dirpath: str):
    method __len__ (line 22) | def __len__(self):
    method __getitem__ (line 25) | def __getitem__(self, idx: int):
  function compute_image_multimodal_metrics (line 33) | def compute_image_multimodal_metrics(

FILE: deepcompressor/app/diffusion/eval/metrics/similarity.py
  class MultiImageDataset (line 19) | class MultiImageDataset(data.Dataset):
    method __init__ (line 20) | def __init__(self, gen_dirpath: str, ref_dirpath_or_dataset: str | dat...
    method __len__ (line 37) | def __len__(self):
    method __getitem__ (line 40) | def __getitem__(self, idx: int):
  function compute_image_similarity_metrics (line 59) | def compute_image_similarity_metrics(

FILE: deepcompressor/app/diffusion/nn/attention.py
  function apply_flux_rope (line 24) | def apply_flux_rope(query, key, image_rotary_emb):
  class DiffusionAttentionProcessor (line 33) | class DiffusionAttentionProcessor(nn.Module):
    method __init__ (line 34) | def __init__(
    method __call__ (line 49) | def __call__(  # noqa: C901

FILE: deepcompressor/app/diffusion/nn/patch.py
  function replace_up_block_conv_with_concat_conv (line 20) | def replace_up_block_conv_with_concat_conv(model: nn.Module) -> None:
  function replace_fused_linear_with_concat_linear (line 57) | def replace_fused_linear_with_concat_linear(model: nn.Module) -> None:
  function shift_input_activations (line 73) | def shift_input_activations(model: nn.Module) -> None:
  function replace_attn_processor (line 109) | def replace_attn_processor(model: nn.Module) -> None:

FILE: deepcompressor/app/diffusion/nn/struct.py
  class DiffusionModuleStruct (line 120) | class DiffusionModuleStruct(BaseModuleStruct):
    method named_key_modules (line 121) | def named_key_modules(self) -> tp.Generator[tuple[str, str, nn.Module,...
  class DiffusionBlockStruct (line 133) | class DiffusionBlockStruct(BaseModuleStruct):
    method iter_attention_structs (line 135) | def iter_attention_structs(self) -> tp.Generator["DiffusionAttentionSt...
    method iter_transformer_block_structs (line 138) | def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTra...
  class DiffusionModelStruct (line 142) | class DiffusionModelStruct(DiffusionBlockStruct):
    method num_blocks (line 148) | def num_blocks(self) -> int: ...
    method block_structs (line 152) | def block_structs(self) -> list[DiffusionBlockStruct]: ...
    method get_prev_module_keys (line 155) | def get_prev_module_keys(self) -> tuple[str, ...]: ...
    method get_post_module_keys (line 158) | def get_post_module_keys(self) -> tuple[str, ...]: ...
    method _get_iter_block_activations_args (line 161) | def _get_iter_block_activations_args(
    method _get_iter_pre_module_activations_args (line 165) | def _get_iter_pre_module_activations_args(
    method _get_iter_post_module_activations_args (line 176) | def _get_iter_post_module_activations_args(
    method get_iter_layer_activations_args (line 187) | def get_iter_layer_activations_args(
    method named_key_modules (line 219) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 227) | def iter_attention_structs(self) -> tp.Generator["AttentionStruct", No...
    method iter_transformer_block_structs (line 231) | def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTra...
    method get_named_layers (line 235) | def get_named_layers(
    method _default_construct (line 249) | def _default_construct(
    method _get_default_key_map (line 270) | def _get_default_key_map(cls) -> dict[str, set[str]]:
    method _simplify_keys (line 284) | def _simplify_keys(keys: tp.Iterable[str], *, key_map: dict[str, set[s...
  class DiffusionAttentionStruct (line 314) | class DiffusionAttentionStruct(AttentionStruct):
    method filter_kwargs (line 319) | def filter_kwargs(self, kwargs: dict) -> dict:
    method _default_construct (line 336) | def _default_construct(
  class DiffusionFeedForwardStruct (line 422) | class DiffusionFeedForwardStruct(FeedForwardStruct):
    method up_proj (line 438) | def up_proj(self) -> nn.Linear:
    method down_proj (line 442) | def down_proj(self) -> nn.Linear:
    method up_proj_rname (line 446) | def up_proj_rname(self) -> str:
    method down_proj_rname (line 450) | def down_proj_rname(self) -> str:
    method up_proj_name (line 454) | def up_proj_name(self) -> str:
    method down_proj_name (line 458) | def down_proj_name(self) -> str:
    method __post_init__ (line 463) | def __post_init__(self) -> None:
    method _default_construct (line 470) | def _default_construct(
  class DiffusionTransformerBlockStruct (line 540) | class DiffusionTransformerBlockStruct(TransformerBlockStruct, DiffusionB...
    method __post_init__ (line 579) | def __post_init__(self) -> None:
    method named_key_modules (line 608) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method _default_construct (line 629) | def _default_construct(
    method _get_default_key_map (line 737) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class DiffusionTransformerStruct (line 779) | class DiffusionTransformerStruct(BaseTransformerStruct, DiffusionBlockSt...
    method num_blocks (line 814) | def num_blocks(self) -> int:
    method block_structs (line 818) | def block_structs(self) -> list[DiffusionBlockStruct]:
    method block_names (line 822) | def block_names(self) -> list[str]:
    method __post_init__ (line 827) | def __post_init__(self):
    method _default_construct (line 847) | def _default_construct(
    method _get_default_key_map (line 886) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class DiffusionResnetStruct (line 905) | class DiffusionResnetStruct(BaseModuleStruct):
    method __post_init__ (line 939) | def __post_init__(self):
    method named_key_modules (line 949) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method construct (line 959) | def construct(
    method _get_default_key_map (line 1044) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class UNetBlockStruct (line 1057) | class UNetBlockStruct(DiffusionBlockStruct):
    class BlockType (line 1058) | class BlockType(enum.StrEnum):
    method downsample (line 1101) | def downsample(self) -> nn.Conv2d | None:
    method upsample (line 1105) | def upsample(self) -> nn.Conv2d | None:
    method __post_init__ (line 1108) | def __post_init__(self) -> None:
    method is_downsample_block (line 1145) | def is_downsample_block(self) -> bool:
    method is_mid_block (line 1148) | def is_mid_block(self) -> bool:
    method is_upsample_block (line 1151) | def is_upsample_block(self) -> bool:
    method has_downsample (line 1154) | def has_downsample(self) -> bool:
    method has_upsample (line 1157) | def has_upsample(self) -> bool:
    method named_key_modules (line 1160) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 1168) | def iter_attention_structs(self) -> tp.Generator[DiffusionAttentionStr...
    method iter_transformer_block_structs (line 1172) | def iter_transformer_block_structs(self) -> tp.Generator[DiffusionTran...
    method _default_construct (line 1177) | def _default_construct(
    method _get_default_key_map (line 1242) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class UNetStruct (line 1267) | class UNetStruct(DiffusionModelStruct):
    method num_down_blocks (line 1353) | def num_down_blocks(self) -> int:
    method num_up_blocks (line 1357) | def num_up_blocks(self) -> int:
    method num_blocks (line 1361) | def num_blocks(self) -> int:
    method block_structs (line 1365) | def block_structs(self) -> list[UNetBlockStruct]:
    method __post_init__ (line 1368) | def __post_init__(self) -> None:
    method get_prev_module_keys (line 1418) | def get_prev_module_keys(self) -> tuple[str, ...]:
    method get_post_module_keys (line 1421) | def get_post_module_keys(self) -> tuple[str, ...]:
    method _get_iter_block_activations_args (line 1424) | def _get_iter_block_activations_args(
    method _default_construct (line 1464) | def _default_construct(
    method _get_default_key_map (line 1519) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class DiTStruct (line 1568) | class DiTStruct(DiffusionModelStruct, DiffusionTransformerStruct):
    method num_blocks (line 1618) | def num_blocks(self) -> int:
    method block_structs (line 1622) | def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
    method block_names (line 1626) | def block_names(self) -> list[str]:
    method __post_init__ (line 1629) | def __post_init__(self) -> None:
    method get_prev_module_keys (line 1654) | def get_prev_module_keys(self) -> tuple[str, ...]:
    method get_post_module_keys (line 1657) | def get_post_module_keys(self) -> tuple[str, ...]:
    method _get_iter_block_activations_args (line 1660) | def _get_iter_block_activations_args(
    method _default_construct (line 1685) | def _default_construct(
    method _get_default_key_map (line 1748) | def _get_default_key_map(cls) -> dict[str, set[str]]:
  class FluxStruct (line 1789) | class FluxStruct(DiTStruct):
    method num_blocks (line 1817) | def num_blocks(self) -> int:
    method block_structs (line 1821) | def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
    method block_names (line 1825) | def block_names(self) -> list[str]:
    method __post_init__ (line 1828) | def __post_init__(self) -> None:
    method _get_iter_block_activations_args (line 1849) | def _get_iter_block_activations_args(
    method _default_construct (line 1861) | def _default_construct(
    method _get_default_key_map (line 1906) | def _get_default_key_map(cls) -> dict[str, set[str]]:

FILE: deepcompressor/app/diffusion/pipeline/config.py
  class LoRAConfig (line 38) | class LoRAConfig:
  class DiffusionPipelineConfig (line 57) | class DiffusionPipelineConfig:
    method __post_init__ (line 95) | def __post_init__(self):
    method build (line 105) | def build(
    method extract_text_encoders (line 129) | def extract_text_encoders(
    method register_pipeline_factory (line 148) | def register_pipeline_factory(
    method register_text_extractor (line 174) | def register_text_extractor(
    method load_lora (line 202) | def load_lora(  # noqa: C901
    method _default_build (line 327) | def _default_build(
    method _default_extract_text_encoders (line 371) | def _default_extract_text_encoders(
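`DiffusionPipelineConfig.register_pipeline_factory` / `register_text_extractor` above follow a register-then-build pattern, with `_default_build` as the fallback. A generic sketch of that pattern, assuming decorator-style registration (all names below are illustrative, not the actual deepcompressor API):

```python
from typing import Callable, Dict

# Hypothetical module-level registry; deepcompressor keys its factories
# differently, but the lookup-with-default shape is the same.
_FACTORIES: Dict[str, Callable[..., object]] = {}

def register_factory(name: str):
    """Decorator that records a factory callable under a string key."""
    def wrap(fn: Callable[..., object]) -> Callable[..., object]:
        _FACTORIES[name] = fn
        return fn
    return wrap

def build(name: str, **kwargs) -> object:
    """Dispatch to a registered factory; unknown keys raise KeyError."""
    factory = _FACTORIES.get(name)
    if factory is None:
        raise KeyError(f"no factory registered for {name!r}")
    return factory(**kwargs)

@register_factory("toy")
def make_toy(scale: int = 1) -> dict:
    return {"kind": "toy", "scale": scale}

print(build("toy", scale=2))  # → {'kind': 'toy', 'scale': 2}
```

In the repo, the default branch would call `_default_build` instead of raising, which is how unregistered pipeline names still work.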

FILE: deepcompressor/app/diffusion/ptq.py
  function ptq (line 27) | def ptq(  # noqa: C901
  function main (line 272) | def main(config: DiffusionPtqRunConfig, logging_level: int = tools.loggi...

FILE: deepcompressor/app/diffusion/quant/activation.py
  function quantize_diffusion_block_activations (line 30) | def quantize_diffusion_block_activations(  # noqa: C901
  function quantize_diffusion_activations (line 193) | def quantize_diffusion_activations(

FILE: deepcompressor/app/diffusion/quant/config.py
  class DiffusionQuantConfig (line 30) | class DiffusionQuantConfig(DiffusionModuleQuantizerConfig):
    method __post_init__ (line 53) | def __post_init__(self) -> None:  # noqa: C901
    method enabled_rotation (line 81) | def enabled_rotation(self) -> bool:
    method enabled_smooth (line 86) | def enabled_smooth(self) -> bool:
    method enabled_smooth_proj (line 91) | def enabled_smooth_proj(self) -> bool:
    method enabled_smooth_attn (line 96) | def enabled_smooth_attn(self) -> bool:
    method needs_acts_quantizer_cache (line 101) | def needs_acts_quantizer_cache(self) -> bool:
    method generate_calib_dirname (line 109) | def generate_calib_dirname(self) -> str:
    method generate_cache_dirpath (line 126) | def generate_cache_dirpath(
    method generate_default_dirname (line 165) | def generate_default_dirname(self) -> str:  # noqa: C901
    method set_key_map (line 415) | def set_key_map(cls, key_map: dict[str, set[str]]) -> None:
    method organize (line 423) | def organize(self) -> dict[str, bool]:  # noqa: C901

FILE: deepcompressor/app/diffusion/quant/quantizer/config.py
  class DiffusionGPTQConfig (line 25) | class DiffusionGPTQConfig(SkipBasedConfig, QuantGptqConfig):
  class DiffusionQuantizerConfig (line 45) | class DiffusionQuantizerConfig(QuantizerConfig):
    method __post_init__ (line 72) | def __post_init__(self) -> None:
    method enabled_gptq (line 87) | def enabled_gptq(self) -> bool:
    method enabled_low_rank (line 92) | def enabled_low_rank(self) -> bool:
    method enabled_calib_range (line 97) | def enabled_calib_range(self) -> bool:
    method generate_calib_dirname (line 101) | def generate_calib_dirname(self) -> str:
  class SkipBasedDiffusionQuantizerConfig (line 121) | class SkipBasedDiffusionQuantizerConfig(SkipBasedConfig, DiffusionQuanti...
    method __post_init__ (line 145) | def __post_init__(self) -> None:
  class DiffusionWeightQuantizerConfig (line 153) | class DiffusionWeightQuantizerConfig(SkipBasedDiffusionQuantizerConfig):
    method needs_calib_data (line 176) | def needs_calib_data(self) -> bool:
  class DiffusionActivationQuantizerConfig (line 182) | class DiffusionActivationQuantizerConfig(SkipBasedDiffusionQuantizerConf...
    method needs_calib_data (line 209) | def needs_calib_data(self) -> bool:
    method generate_dirnames (line 212) | def generate_dirnames(
    method for_unsigned (line 243) | def for_unsigned(self) -> "DiffusionActivationQuantizerConfig":
  class DiffusionExtraWeightQuantizerConfig (line 267) | class DiffusionExtraWeightQuantizerConfig(IncludeBasedConfig, DiffusionQ...
    method needs_calib_data (line 293) | def needs_calib_data(self) -> bool:
  class DiffusionModuleQuantizerConfig (line 299) | class DiffusionModuleQuantizerConfig(EnableConfig):
    method is_enabled (line 317) | def is_enabled(self):
    method enabled_wgts (line 321) | def enabled_wgts(self) -> bool:
    method enabled_ipts (line 326) | def enabled_ipts(self) -> bool:
    method enabled_opts (line 331) | def enabled_opts(self) -> bool:
    method enabled_extra_wgts (line 336) | def enabled_extra_wgts(self) -> bool:
    method __post_init__ (line 340) | def __post_init__(self) -> None:
    method generate_dirnames (line 354) | def generate_dirnames(
    method generate_calib_dirname (line 394) | def generate_calib_dirname(self) -> str:

FILE: deepcompressor/app/diffusion/quant/quantizer/quantizer.py
  class DiffusionQuantizer (line 29) | class DiffusionQuantizer(Quantizer):
    method __post_init__ (line 73) | def __post_init__(self) -> None:
    method calibrate_dynamic_range (line 77) | def calibrate_dynamic_range(
  class DiffusionWeightQuantizer (line 144) | class DiffusionWeightQuantizer(DiffusionQuantizer):
    method calibrate_dynamic_range (line 185) | def calibrate_dynamic_range(
    method calibrate_low_rank (line 234) | def calibrate_low_rank(
  class DiffusionActivationQuantizer (line 267) | class DiffusionActivationQuantizer(DiffusionQuantizer):
    method __post_init__ (line 304) | def __post_init__(self) -> None:

FILE: deepcompressor/app/diffusion/quant/rotate.py
  function rotate_diffusion (line 23) | def rotate_diffusion(  # noqa: C901

FILE: deepcompressor/app/diffusion/quant/smooth.py
  function smooth_diffusion_attention (line 30) | def smooth_diffusion_attention(
  function smooth_diffusion_qkv_proj (line 46) | def smooth_diffusion_qkv_proj(
  function smooth_diffusion_out_proj (line 135) | def smooth_diffusion_out_proj(  # noqa: C901
  function smooth_diffusion_up_proj (line 236) | def smooth_diffusion_up_proj(
  function smooth_diffusion_down_proj (line 280) | def smooth_diffusion_down_proj(
  function smooth_diffusion_parallel_qkv_up_proj (line 317) | def smooth_diffusion_parallel_qkv_up_proj(
  function smooth_diffusion_sequential_transformer_block (line 404) | def smooth_diffusion_sequential_transformer_block(
  function smooth_diffusion_parallel_transformer_block (line 448) | def smooth_diffusion_parallel_transformer_block(
  function smooth_diffusion_module (line 487) | def smooth_diffusion_module(
  function smooth_diffusion_layer (line 529) | def smooth_diffusion_layer(
  function smooth_diffusion (line 601) | def smooth_diffusion(

FILE: deepcompressor/app/diffusion/quant/utils.py
  function wrap_joint_attn (line 12) | def wrap_joint_attn(attn: nn.Module, /, *, indexes: int | tuple[int, ......
  function get_needs_inputs_fn (line 28) | def get_needs_inputs_fn(
  function get_needs_outputs_fn (line 82) | def get_needs_outputs_fn(

FILE: deepcompressor/app/diffusion/quant/weight.py
  function calibrate_diffusion_block_low_rank_branch (line 25) | def calibrate_diffusion_block_low_rank_branch(  # noqa: C901
  function update_diffusion_block_weight_quantizer_state_dict (line 151) | def update_diffusion_block_weight_quantizer_state_dict(
  function quantize_diffusion_block_weights (line 217) | def quantize_diffusion_block_weights(
  function quantize_diffusion_weights (line 296) | def quantize_diffusion_weights(
  function load_diffusion_weights_state_dict (line 436) | def load_diffusion_weights_state_dict(

FILE: deepcompressor/app/diffusion/utils.py
  function update_mask (line 13) | def update_mask(mask: np.ndarray, x: int, y: int, radius: int | float):
  function generate_mask (line 23) | def generate_mask(
  function center_crop_and_resize (line 46) | def center_crop_and_resize(image: Image.Image, target_size: int | tuple[...
  function get_control (line 70) | def get_control(  # noqa: C901
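`center_crop_and_resize` above takes an image to a target size without distortion. The crop-box arithmetic behind that operation can be sketched in pure Python (box semantics follow PIL's `Image.crop`; the function name and rounding choice here are assumptions, not the repo's exact code):

```python
def center_crop_box(width: int, height: int, target_w: int, target_h: int):
    """Largest centered box with the target aspect ratio inside (width, height).

    Returns (left, top, right, bottom) as PIL's Image.crop expects; the
    cropped region is then resized to (target_w, target_h).
    """
    # Scale factor that fits the target aspect ratio inside the source.
    scale = min(width / target_w, height / target_h)
    crop_w, crop_h = round(target_w * scale), round(target_h * scale)
    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    return left, top, left + crop_w, top + crop_h

# A 1024x768 source cropped for a square target keeps the central 768x768.
print(center_crop_box(1024, 768, 512, 512))  # → (128, 0, 896, 768)
```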

FILE: deepcompressor/app/llm/cache/config.py
  class LlmQuantCacheConfig (line 15) | class LlmQuantCacheConfig(BasePathConfig):
  class LlmCacheConfig (line 40) | class LlmCacheConfig:

FILE: deepcompressor/app/llm/config.py
  class LlmPtqRunConfig (line 33) | class LlmPtqRunConfig:
    method __post_init__ (line 70) | def __post_init__(self):  # noqa: C901
    method get_parser (line 119) | def get_parser(cls) -> ConfigParser:

FILE: deepcompressor/app/llm/eval/base.py
  class LlmEvaluatorBase (line 11) | class LlmEvaluatorBase(ABC):
    method __init__ (line 12) | def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokeni...
    method filter_tasks (line 16) | def filter_tasks(self, tasks: list[str]) -> list[str]:
    method evaluate (line 21) | def evaluate(self, tasks: list[str], **kwargs) -> dict[str, dict[str, ...
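`LlmEvaluatorBase` above defines the contract shared by the custom, lm-eval, and LongBench evaluators: construct with model and tokenizer, `filter_tasks` to select supported tasks, `evaluate` to return nested result dicts. A minimal sketch of that contract (model/tokenizer types omitted; the subclass and its task prefix are invented for illustration):

```python
from abc import ABC, abstractmethod

class EvaluatorBase(ABC):
    """Sketch of the LlmEvaluatorBase shape; not the repo's exact class."""

    def __init__(self, model, tokenizer):
        self.model, self.tokenizer = model, tokenizer

    def filter_tasks(self, tasks: list[str]) -> list[str]:
        # Default: accept every requested task; subclasses narrow this.
        return tasks

    @abstractmethod
    def evaluate(self, tasks: list[str], **kwargs) -> dict[str, dict[str, float]]: ...

class PplOnlyEvaluator(EvaluatorBase):
    """Toy subclass that only handles tasks prefixed 'ppl'."""

    def filter_tasks(self, tasks):
        return [t for t in tasks if t.startswith("ppl")]

    def evaluate(self, tasks, **kwargs):
        return {t: {"score": 0.0} for t in self.filter_tasks(tasks)}

ev = PplOnlyEvaluator(model=None, tokenizer=None)
print(ev.evaluate(["ppl_wikitext", "mmlu"]))  # → {'ppl_wikitext': {'score': 0.0}}
```

The split lets `LlmEvalConfig.evaluate` fan one task list out across several backends, each claiming only the tasks it supports.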

FILE: deepcompressor/app/llm/eval/config.py
  class LlmEvalConfig (line 25) | class LlmEvalConfig:
    method __post_init__ (line 62) | def __post_init__(self):
    method evaluate (line 74) | def evaluate(
    method make_table (line 171) | def make_table(rst: dict[str, dict[tp.Any, dict[str, tp.Any]]]) -> str:
  function get_max_seq_length (line 207) | def get_max_seq_length(model: PreTrainedModel, tokenizer: PreTrainedToke...

FILE: deepcompressor/app/llm/eval/custom.py
  class LlmCustomEvaluator (line 17) | class LlmCustomEvaluator(LlmEvaluatorBase):
    method filter_tasks (line 18) | def filter_tasks(self, tasks: list[str]) -> list[str]:
    method evaluate (line 22) | def evaluate(
  function _eval_ppl_with_gptq_evaluator (line 45) | def _eval_ppl_with_gptq_evaluator(

FILE: deepcompressor/app/llm/eval/lm_eval.py
  class LmevalEvaluator (line 13) | class LmevalEvaluator(LlmEvaluatorBase):
    method __init__ (line 14) | def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokeni...
    method filter_tasks (line 18) | def filter_tasks(self, tasks: list[str]) -> list[str]:
    method evaluate (line 22) | def evaluate(

FILE: deepcompressor/app/llm/eval/longbench/eval.py
  class LongbenchEvaluator (line 33) | class LongbenchEvaluator(LlmEvaluatorBase):
    method __init__ (line 60) | def __init__(
    method filter_tasks (line 80) | def filter_tasks(self, tasks: list[str]) -> list[str]:
    method evaluate (line 88) | def evaluate(self, tasks: list[str], max_length: int, **kwargs) -> dic...
    method predict (line 161) | def predict(
    method build_chat (line 218) | def build_chat(self, prompt):
    method post_process (line 224) | def post_process(self, response: str) -> str:
  class LongbenchScorer (line 242) | class LongbenchScorer:
    method score (line 268) | def score(
    method scorer_e (line 307) | def scorer_e(

FILE: deepcompressor/app/llm/eval/longbench/metrics.py
  function normalize_answer (line 24) | def normalize_answer(s: str) -> str:
  function normalize_zh_answer (line 40) | def normalize_zh_answer(s: str) -> str:
  function count_score (line 56) | def count_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function retrieval_score (line 66) | def retrieval_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function retrieval_zh_score (line 78) | def retrieval_zh_score(prediction: str, ground_truth: str, **kwargs) -> ...
  function code_sim_score (line 90) | def code_sim_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function classification_score (line 100) | def classification_score(prediction: str, ground_truth: str, **kwargs) -...
  function rouge_score (line 109) | def rouge_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function rouge_zh_score (line 117) | def rouge_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function f1_score (line 123) | def f1_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function qa_f1_score (line 133) | def qa_f1_score(prediction: str, ground_truth: str, **kwargs) -> float:
  function qa_f1_zh_score (line 141) | def qa_f1_zh_score(prediction: str, ground_truth: str, **kwargs) -> float:
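The `normalize_answer` / `f1_score` pair above follows the standard SQuAD-style QA metric: normalize both strings, then compute token-overlap F1. A self-contained reimplementation of that standard metric (illustrative; the repo's versions may differ in detail, e.g. the zh variants use jieba tokenization):

```python
import re
import string
from collections import Counter

def normalize_answer(s: str) -> str:
    """Lowercase, drop punctuation and articles, squeeze whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between normalized prediction and reference."""
    pred = normalize_answer(prediction).split()
    gold = normalize_answer(ground_truth).split()
    common = Counter(pred) & Counter(gold)  # multiset intersection
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

print(f1_score("The cat sat.", "a cat sat"))  # → 1.0
```

Normalization makes the metric robust to casing, punctuation, and article choice, which is why both strings pass through it before the multiset overlap.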

FILE: deepcompressor/app/llm/model/config.py
  class LlmModelConfig (line 21) | class LlmModelConfig(BaseModelConfig):
    method __post_init__ (line 56) | def __post_init__(self):
    method build (line 103) | def build(self) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    method _default_build (line 125) | def _default_build(path: str, **kwargs) -> tuple[PreTrainedModel, PreT...
    method register_model_factory (line 149) | def register_model_factory(

FILE: deepcompressor/app/llm/nn/patch.py
  function rotate_half (line 17) | def rotate_half(x):
  function update_rotary_cos_sin (line 24) | def update_rotary_cos_sin(
  class RotaryEmbedding (line 59) | class RotaryEmbedding(nn.Module):
    method __init__ (line 62) | def __init__(self) -> None:
    method forward (line 66) | def forward(
  function apply_rotary_pos_emb (line 90) | def apply_rotary_pos_emb(
  function patch_attention (line 140) | def patch_attention(model: nn.Module) -> nn.Module:
  function gemma_rms_norm_forward (line 176) | def gemma_rms_norm_forward(self: GemmaRMSNorm | Gemma2RMSNorm, x: torch....
  function patch_gemma_rms_norm (line 186) | def patch_gemma_rms_norm(model: nn.Module) -> nn.Module:
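`rotate_half` and `apply_rotary_pos_emb` above are the standard rotary position embedding (RoPE) helpers: split the head dimension in half, map (x1, x2) to (-x2, x1), and combine with cos/sin tables. A pure-Python sketch of that math on flat lists (the repo operates on torch tensors with broadcasting; `apply_rope` is a hypothetical name):

```python
def rotate_half(x: list[float]) -> list[float]:
    """Standard RoPE helper: (x1, x2) -> (-x2, x1) along the last dimension."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

def apply_rope(x: list[float], cos: list[float], sin: list[float]) -> list[float]:
    """x_rot = x * cos + rotate_half(x) * sin, elementwise."""
    r = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, r, sin)]

# theta = 0 (cos=1, sin=0) leaves the vector unchanged.
print(apply_rope([1.0, 2.0], [1.0, 1.0], [0.0, 0.0]))  # → [1.0, 2.0]
```

For a 2-dimensional head this reduces to a planar rotation of (x1, x2) by the position-dependent angle, which is exactly the relative-position property RoPE relies on.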

FILE: deepcompressor/app/llm/nn/struct.py
  class LlmTransformerBlockConfigStruct (line 115) | class LlmTransformerBlockConfigStruct(FeedForwardConfigStruct, Attention...
  class LlmTransformerConfigStruct (line 151) | class LlmTransformerConfigStruct(LlmTransformerBlockConfigStruct):
  class LlmConfigStruct (line 192) | class LlmConfigStruct(LlmTransformerConfigStruct):
  class LlmSelfAttentionStruct (line 234) | class LlmSelfAttentionStruct(SelfAttentionStruct):
    method filter_kwargs (line 246) | def filter_kwargs(self, kwargs: dict) -> dict:
    method _default_construct (line 251) | def _default_construct(
  class LlmFeedForwardStruct (line 345) | class LlmFeedForwardStruct(FeedForwardStruct):
    method _default_construct (line 351) | def _default_construct(
  class LlmTransformerBlockStruct (line 449) | class LlmTransformerBlockStruct(TransformerBlockStruct):
    method pre_attn_norm (line 487) | def pre_attn_norm(self) -> nn.LayerNorm | None:
    method attn (line 491) | def attn(self) -> nn.Module:
    method post_attn_norm (line 495) | def post_attn_norm(self) -> nn.LayerNorm | None:
    method pre_attn_norm_rname (line 499) | def pre_attn_norm_rname(self) -> str:
    method attn_rname (line 503) | def attn_rname(self) -> str:
    method post_attn_norm_rname (line 507) | def post_attn_norm_rname(self) -> str:
    method pre_attn_norm_name (line 511) | def pre_attn_norm_name(self) -> str:
    method attn_name (line 515) | def attn_name(self) -> str:
    method post_attn_norm_name (line 519) | def post_attn_norm_name(self) -> str:
    method attn_struct (line 523) | def attn_struct(self) -> LlmSelfAttentionStruct:
    method __post_init__ (line 528) | def __post_init__(self):
    method _default_construct (line 545) | def _default_construct(
  class LlmTransformerStruct (line 613) | class LlmTransformerStruct(BaseTransformerStruct):
    method num_blocks (line 651) | def num_blocks(self) -> int:
    method block_structs (line 656) | def block_structs(self) -> list[LlmTransformerBlockStruct]:
    method block_names (line 660) | def block_names(self) -> list[str]:
    method __post_init__ (line 666) | def __post_init__(self) -> None:
    method get_iter_layer_activations_args (line 702) | def get_iter_layer_activations_args(
    method _default_construct (line 721) | def _default_construct(
  class LlmModelStruct (line 778) | class LlmModelStruct(BaseModuleStruct):
    method __post_init__ (line 808) | def __post_init__(self) -> None:
    method named_key_modules (line 820) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 825) | def iter_attention_structs(self) -> tp.Generator[LlmSelfAttentionStruc...
    method iter_transformer_block_structs (line 828) | def iter_transformer_block_structs(self) -> tp.Generator[LlmTransforme...
    method get_iter_layer_activations_args (line 831) | def get_iter_layer_activations_args(
    method _default_construct (line 850) | def _default_construct(

FILE: deepcompressor/app/llm/ptq.py
  function ptq (line 22) | def ptq(  # noqa: C901
  function main (line 318) | def main(config: LlmPtqRunConfig, logging_level: int = tools.logging.DEB...

FILE: deepcompressor/app/llm/quant/activation.py
  function quantize_llm_layer_activations (line 25) | def quantize_llm_layer_activations(  # noqa: C901
  function quantize_llm_activations (line 196) | def quantize_llm_activations(

FILE: deepcompressor/app/llm/quant/config.py
  class LlmQuantConfig (line 31) | class LlmQuantConfig(LlmModuleQuantizerConfig):
    method __post_init__ (line 59) | def __post_init__(self) -> None:  # noqa: C901
    method enabled_smooth (line 95) | def enabled_smooth(self) -> bool:
    method enabled_smooth_proj (line 100) | def enabled_smooth_proj(self) -> bool:
    method enabled_smooth_attn (line 105) | def enabled_smooth_attn(self) -> bool:
    method enabled_reorder (line 110) | def enabled_reorder(self) -> bool:
    method enabled_rotation (line 115) | def enabled_rotation(self) -> bool:
    method needs_acts_quantizer_cache (line 120) | def needs_acts_quantizer_cache(self) -> bool:
    method generate_calib_dirname (line 128) | def generate_calib_dirname(self) -> str:
    method generate_default_dirname (line 149) | def generate_default_dirname(self) -> str:  # noqa: C901
    method generate_cache_dirpath (line 329) | def generate_cache_dirpath(

FILE: deepcompressor/app/llm/quant/dataset.py
  class LlmCalibDataLoaderConfig (line 33) | class LlmCalibDataLoaderConfig(BaseDataLoaderConfig):
    method __post_init__ (line 60) | def __post_init__(self) -> None:
    method generate_dirnames (line 68) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method build_dataset (line 73) | def build_dataset(self, tokenizer: PreTrainedTokenizer) -> "LlmCalibDa...
    method build_loader (line 94) | def build_loader(self, tokenizer: PreTrainedTokenizer) -> "LlmCalibCac...
  class LlmCalibDataset (line 108) | class LlmCalibDataset(torch.utils.data.Dataset):
    method __init__ (line 111) | def __init__(
    method __len__ (line 158) | def __len__(self) -> int:
    method __getitem__ (line 161) | def __getitem__(self, idx: int) -> torch.Tensor:
  class LlmCalibCacheLoader (line 165) | class LlmCalibCacheLoader(BaseCalibCacheLoader):
    method __init__ (line 171) | def __init__(self, config: LlmCalibDataLoaderConfig, tokenizer: PreTra...
    method _init_cache (line 184) | def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
    method _convert_layer_inputs (line 207) | def _convert_layer_inputs(
    method iter_samples (line 231) | def iter_samples(self) -> tp.Generator[ModuleForwardInput, None, None]:
    method iter_layer_activations (line 248) | def iter_layer_activations(  # noqa: C901

FILE: deepcompressor/app/llm/quant/quantizer/config.py
  class LlmQuantizerConfig (line 21) | class LlmQuantizerConfig(SkipBasedConfig, ProgressiveQuantizerConfig):
    method __post_init__ (line 53) | def __post_init__(self) -> None:
    method enabled_gptq (line 63) | def enabled_gptq(self) -> bool:
    method enabled_calib_range (line 68) | def enabled_calib_range(self) -> bool:
    method needs_calib_data (line 73) | def needs_calib_data(self) -> bool:
    method generate_calib_dirname (line 76) | def generate_calib_dirname(self) -> str:
  class LlmWeightQuantizerConfig (line 94) | class LlmWeightQuantizerConfig(LlmQuantizerConfig):
  class LlmActivationQuantizerConfig (line 125) | class LlmActivationQuantizerConfig(LlmQuantizerConfig):
  class LlmModuleQuantizerConfig (line 153) | class LlmModuleQuantizerConfig(EnableConfig):
    method is_enabled (line 169) | def is_enabled(self) -> bool:
    method enabled_wgts (line 174) | def enabled_wgts(self) -> bool:
    method enabled_ipts (line 179) | def enabled_ipts(self) -> bool:
    method enabled_opts (line 184) | def enabled_opts(self) -> bool:
    method generate_dirnames (line 188) | def generate_dirnames(
    method generate_calib_dirname (line 225) | def generate_calib_dirname(self) -> str:

FILE: deepcompressor/app/llm/quant/quantizer/quantizer.py
  class LlmQuantizer (line 23) | class LlmQuantizer(Quantizer):
    method __post_init__ (line 62) | def __post_init__(self) -> None:
    method calibrate_dynamic_range (line 65) | def calibrate_dynamic_range(
  class LlmWeightQuantizer (line 132) | class LlmWeightQuantizer(LlmQuantizer):
    method calibrate_dynamic_range (line 167) | def calibrate_dynamic_range(
  class LlmActivationQuantizer (line 218) | class LlmActivationQuantizer(LlmQuantizer):
    method __post_init__ (line 256) | def __post_init__(self) -> None:

FILE: deepcompressor/app/llm/quant/reorder.py
  function _extend_params_ (line 25) | def _extend_params_(
  function reorder_llm_layer (line 49) | def reorder_llm_layer(  # noqa: C901
  function reorder_llm (line 240) | def reorder_llm(  # noqa: C901

FILE: deepcompressor/app/llm/quant/rotate.py
  function rotate_llm (line 26) | def rotate_llm(  # noqa: C901

FILE: deepcompressor/app/llm/quant/smooth.py
  function smooth_llm_layer (line 24) | def smooth_llm_layer(  # noqa: C901
  function smooth_llm (line 171) | def smooth_llm(

FILE: deepcompressor/app/llm/quant/utils.py
  function get_needs_inputs_fn (line 14) | def get_needs_inputs_fn(model: LlmModelStruct, config: LlmModuleQuantize...
  function get_needs_outputs_fn (line 61) | def get_needs_outputs_fn(

FILE: deepcompressor/app/llm/quant/weight.py
  function quantize_llm_layer_weights (line 25) | def quantize_llm_layer_weights(  # noqa: C901
  function quantize_llm_weights (line 109) | def quantize_llm_weights(

FILE: deepcompressor/backend/nunchaku/convert.py
  function convert_to_nunchaku_w4x4y16_linear_state_dict (line 13) | def convert_to_nunchaku_w4x4y16_linear_state_dict(
  function convert_to_nunchaku_w4x16_adanorm_single_state_dict (line 75) | def convert_to_nunchaku_w4x16_adanorm_single_state_dict(
  function convert_to_nunchaku_w4x16_adanorm_zero_state_dict (line 92) | def convert_to_nunchaku_w4x16_adanorm_zero_state_dict(
  function update_state_dict (line 109) | def update_state_dict(
  function convert_to_nunchaku_transformer_block_state_dict (line 119) | def convert_to_nunchaku_transformer_block_state_dict(
  function convert_to_nunchaku_flux_single_transformer_block_state_dict (line 221) | def convert_to_nunchaku_flux_single_transformer_block_state_dict(
  function convert_to_nunchaku_flux_transformer_block_state_dict (line 272) | def convert_to_nunchaku_flux_transformer_block_state_dict(
  function convert_to_nunchaku_flux_state_dicts (line 347) | def convert_to_nunchaku_flux_state_dicts(

FILE: deepcompressor/backend/nunchaku/convert_lora.py
  function reorder_adanorm_lora_up (line 16) | def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch...
  function convert_to_nunchaku_transformer_block_lowrank_dict (line 22) | def convert_to_nunchaku_transformer_block_lowrank_dict(  # noqa: C901
  function convert_to_nunchaku_flux_single_transformer_block_lowrank_dict (line 146) | def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
  function convert_to_nunchaku_flux_transformer_block_lowrank_dict (line 193) | def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
  function convert_to_nunchaku_flux_lowrank_dict (line 237) | def convert_to_nunchaku_flux_lowrank_dict(

FILE: deepcompressor/backend/nunchaku/utils.py
  class NunchakuWeightPacker (line 16) | class NunchakuWeightPacker(MmaWeightPackerBase):
    method __init__ (line 17) | def __init__(self, bits: int, warp_n: int = 128):
    method pack_weight (line 21) | def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
    method pack_scale (line 62) | def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Te...
    method pack_micro_scale (line 109) | def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> to...
    method pack_lowrank_weight (line 153) | def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> tor...
    method unpack_lowrank_weight (line 184) | def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> t...
    method check_if_micro_scale (line 216) | def check_if_micro_scale(self, group_size: int) -> bool:
    method pad_weight (line 219) | def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
    method pad_scale (line 223) | def pad_scale(self, scale: torch.Tensor, group_size: int) -> torch.Ten...
    method pad_lowrank_weight (line 234) | def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torc...
  function convert_to_nunchaku_w4x4y16_linear_weight (line 239) | def convert_to_nunchaku_w4x4y16_linear_weight(
  function convert_to_nunchaku_w8x8y16_linear_weight (line 318) | def convert_to_nunchaku_w8x8y16_linear_weight(
  function convert_to_nunchaku_w4x16_linear_weight (line 341) | def convert_to_nunchaku_w4x16_linear_weight(

FILE: deepcompressor/backend/qserve/convert.py
  function convert_to_qserve_w4x8y16_linear_state_dict (line 15) | def convert_to_qserve_w4x8y16_linear_state_dict(
  function convert_to_qserve_w8x8y16_linear_state_dict (line 58) | def convert_to_qserve_w8x8y16_linear_state_dict(
  function convert_to_qserve_state_dict (line 83) | def convert_to_qserve_state_dict(

FILE: deepcompressor/backend/qserve/utils.py
  class QServePacker (line 11) | class QServePacker(MmaWeightPackerBase):
    method __init__ (line 12) | def __init__(self):
    method pack_weight (line 18) | def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
    method pack_scale (line 46) | def pack_scale(
  function convert_to_qserve_w4x8y16_linear_weight (line 64) | def convert_to_qserve_w4x8y16_linear_weight(
  function convert_to_qserve_w8x8y16_linear_weight (line 182) | def convert_to_qserve_w8x8y16_linear_weight(

FILE: deepcompressor/backend/tinychat/convert.py
  function convert_to_tinychat_w4x16y16_linear_state_dict (line 14) | def convert_to_tinychat_w4x16y16_linear_state_dict(
  function convert_to_tinychat_state_dict (line 50) | def convert_to_tinychat_state_dict(

FILE: deepcompressor/backend/tinychat/csrc/pybind.cpp
  function PYBIND11_MODULE (line 6) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)

FILE: deepcompressor/backend/tinychat/csrc/quantization/gemm/semaphore.h
  class Semaphore (line 44) | class Semaphore

FILE: deepcompressor/backend/tinychat/linear.py
  class W4Linear (line 21) | class W4Linear(nn.Module):
    method __init__ (line 22) | def __init__(
    method weight_bits (line 71) | def weight_bits(self) -> int:
    method interleave (line 75) | def interleave(self) -> int:
    method forward (line 79) | def forward(self, x):
    method from_linear (line 97) | def from_linear(
    method extra_repr (line 174) | def extra_repr(self) -> str:

FILE: deepcompressor/backend/tinychat/utils.py
  function ceil_num_groups (line 11) | def ceil_num_groups(in_features: int, group_size: int, weight_bits: int ...
  function pack_w4 (line 45) | def pack_w4(weight: torch.Tensor) -> torch.Tensor:
  function convert_to_tinychat_w4x16y16_linear_weight (line 56) | def convert_to_tinychat_w4x16y16_linear_weight(

FILE: deepcompressor/backend/utils.py
  function ceil_divide (line 12) | def ceil_divide(x: int, divisor: int) -> int:
  function pad (line 28) | def pad(
  function load_state_dict_in_safetensors (line 55) | def load_state_dict_in_safetensors(
  function fp_quantize (line 81) | def fp_quantize(x: torch.Tensor, codebook: torch.Tensor | None = None) -...
  class MmaWeightPackerBase (line 91) | class MmaWeightPackerBase:
    method __init__ (line 92) | def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k:...
    method get_view_shape (line 143) | def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, ...

FILE: deepcompressor/calib/config/lowrank.py
  class QuantLowRankCalibConfig (line 18) | class QuantLowRankCalibConfig(SearchBasedCalibConfig, QuantLowRankConfig):
    method __post_init__ (line 51) | def __post_init__(self):
    method generate_dirnames (line 58) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class SkipBasedQuantLowRankCalibConfig (line 80) | class SkipBasedQuantLowRankCalibConfig(SkipBasedConfig, QuantLowRankCali...

FILE: deepcompressor/calib/config/range.py
  class DynamicRangeCalibConfig (line 17) | class DynamicRangeCalibConfig(SearchBasedCalibConfig):
    method get_linear_ratios (line 59) | def get_linear_ratios(self) -> list[float]:
    method get_ratios (line 73) | def get_ratios(self) -> list[list[float]]:
    method generate_dirnames (line 87) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class SkipBasedDynamicRangeCalibConfig (line 115) | class SkipBasedDynamicRangeCalibConfig(SkipBasedConfig, DynamicRangeCali...

FILE: deepcompressor/calib/config/reorder.py
  class ChannelOrderCalibConfig (line 22) | class ChannelOrderCalibConfig(SearchBasedCalibConfig):
    class ChannelMetric (line 46) | class ChannelMetric(enum.Enum):
    class ChannelIndex (line 59) | class ChannelIndex(enum.Enum):
    method __post_init__ (line 76) | def __post_init__(self) -> None:
    method generate_dirnames (line 81) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class SkipBasedChannelOrderConfig (line 114) | class SkipBasedChannelOrderConfig(SkipBasedConfig, ChannelOrderCalibConf...

FILE: deepcompressor/calib/config/rotation.py
  class QuantRotationConfig (line 16) | class QuantRotationConfig:
    method __post_init__ (line 36) | def __post_init__(self) -> None:
    method generate_dirnames (line 45) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method update_get_arguments (line 57) | def update_get_arguments(
    method update_from_dict (line 77) | def update_from_dict(

FILE: deepcompressor/calib/config/search.py
  class SearchBasedCalibStrategy (line 19) | class SearchBasedCalibStrategy(enum.Enum):
  class SearchBasedCalibGranularity (line 30) | class SearchBasedCalibGranularity(enum.Enum):
  class SearchBasedCalibObjective (line 38) | class SearchBasedCalibObjective(enum.Enum):
  class SearchBasedCalibConfig (line 51) | class SearchBasedCalibConfig:
    method __post_init__ (line 88) | def __post_init__(self) -> None:
    method needs_search (line 106) | def needs_search(self) -> bool:
    method generate_dirnames (line 110) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:

FILE: deepcompressor/calib/config/smooth.py
  class SmoothSpanMode (line 28) | class SmoothSpanMode(enum.Enum):
  class SmoothCalibConfig (line 37) | class SmoothCalibConfig(SearchBasedCalibConfig):
    method __post_init__ (line 98) | def __post_init__(self) -> None:  # noqa: C901
    method get_alpha_beta_pairs (line 134) | def get_alpha_beta_pairs(self) -> list[tuple[float, float]]:  # noqa: ...
    method generate_dirnames (line 224) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class SkipBasedSmoothCalibConfig (line 264) | class SkipBasedSmoothCalibConfig(SkipBasedConfig, SmoothCalibConfig):
  class SmoothAttentionCalibConfig (line 311) | class SmoothAttentionCalibConfig(SmoothCalibConfig):
  class SmoothTransfomerConfig (line 349) | class SmoothTransfomerConfig:
    method enabled_proj (line 363) | def enabled_proj(self) -> bool:
    method enabled_attn (line 368) | def enabled_attn(self) -> bool:
    method generate_dirnames (line 372) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:

FILE: deepcompressor/calib/lowrank.py
  class QuantLowRankCalibrator (line 20) | class QuantLowRankCalibrator(SearchBasedCalibrator[QuantLowRankCalibConf...
    method __init__ (line 23) | def __init__(
    method population_size (line 58) | def population_size(self) -> int:
    method allows_x_quant_for_wgts (line 63) | def allows_x_quant_for_wgts(self) -> bool:
    method allows_w_quant_for_wgts (line 68) | def allows_w_quant_for_wgts(self) -> bool:
    method is_done (line 72) | def is_done(self) -> bool:
    method is_last_iter (line 76) | def is_last_iter(self) -> bool:
    method _reset (line 80) | def _reset(self, x_wgts: list[torch.Tensor | nn.Parameter], **kwargs) ...
    method get_best (line 109) | def get_best(self) -> LowRankBranch:
    method _ask (line 118) | def _ask(self) -> LowRankBranch:
    method _tell (line 156) | def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
    method _process_x_in_xw (line 188) | def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_w_in_xw (line 193) | def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
    method _process_y_in_yx (line 199) | def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSI...
    method _process_x_in_yx (line 202) | def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_xw_in_yx (line 205) | def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_yw_in_yx (line 208) | def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_wgts_centric_mod (line 211) | def _process_wgts_centric_mod(

FILE: deepcompressor/calib/metric.py
  class ChannelMetric (line 13) | class ChannelMetric:
    method _normalize (line 17) | def _normalize(
    method _abs_max (line 32) | def _abs_max(
    method _abs_sum (line 49) | def _abs_sum(
    method _abs_normalize_sum (line 61) | def _abs_normalize_sum(
    method _square_sum (line 77) | def _square_sum(
    method _max_reduce (line 89) | def _max_reduce(
    method _sum_reduce (line 115) | def _sum_reduce(
    method abs_max (line 142) | def abs_max(
    method abs_mean (line 155) | def abs_mean(
    method abs_normalize_mean (line 169) | def abs_normalize_mean(
    method root_mean_square (line 184) | def root_mean_square(

FILE: deepcompressor/calib/range.py
  class DynamicRangeCalibrator (line 25) | class DynamicRangeCalibrator(SearchBasedCalibrator[DynamicRangeCalibConf...
    method __init__ (line 28) | def __init__(
    method population_size (line 65) | def population_size(self) -> int:
    method is_clamp_based (line 69) | def is_clamp_based(self) -> bool:
    method _reset (line 73) | def _reset(  # noqa: C901
    method get_best (line 152) | def get_best(self) -> DynamicRange:
    method _ask (line 166) | def _ask(self) -> DynamicRange:
    method _tell (line 183) | def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
    method _preprocess_with_pre_scale (line 252) | def _preprocess_with_pre_scale(self, t: torch.Tensor) -> torch.Tensor:
    method _process_wxy (line 260) | def _process_wxy(self, tensor: torch.Tensor, channels_dim: int | _MISS...
    method _process_x_in_xw (line 277) | def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_w_in_xw (line 282) | def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
    method _process_y_in_yx (line 287) | def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSI...
    method _process_x_in_yx (line 292) | def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_xw_in_yx (line 295) | def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_yw_in_yx (line 298) | def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
  function calibrate_dynamic_range (line 302) | def calibrate_dynamic_range(

FILE: deepcompressor/calib/reorder.py
  class ChannelReorderer (line 29) | class ChannelReorderer(BaseTensorProcessor):
    method is_enabled (line 39) | def is_enabled(self) -> bool:
    method get_input_packager (line 42) | def get_input_packager(self) -> BaseInputPackager | None:
    method get_output_packager (line 45) | def get_output_packager(self) -> BaseOutputPackager | None:
    method process (line 48) | def process(self, tensor: torch.Tensor) -> torch.Tensor:
  function get_channel_index_from_rank (line 61) | def get_channel_index_from_rank(
  function get_channel_metric (line 91) | def get_channel_metric(
  function update_channel_metric (line 160) | def update_channel_metric(
  function init_channel_index_from_metric (line 211) | def init_channel_index_from_metric(
  class ChannelOrderCalibrator (line 268) | class ChannelOrderCalibrator(SearchBasedCalibrator[ChannelOrderCalibConf...
    method __init__ (line 271) | def __init__(
    method population_size (line 318) | def population_size(self) -> int:
    method allows_x_quant_for_wgts (line 324) | def allows_x_quant_for_wgts(self) -> bool:
    method allows_w_quant_for_wgts (line 329) | def allows_w_quant_for_wgts(self) -> bool:
    method update_channel_metrics (line 333) | def update_channel_metrics(self, weights: list[torch.Tensor | nn.Param...
    method init_channel_indexes (line 357) | def init_channel_indexes(self) -> None:
    method _reset (line 389) | def _reset(self, x_wgts: list[torch.Tensor | nn.Parameter], x_acts: Te...
    method get_best (line 405) | def get_best(self) -> torch.Tensor:
    method _ask (line 413) | def _ask(self) -> torch.Tensor:
    method _tell (line 427) | def _tell(self, errors: list[tuple[torch.Tensor, ...]]) -> None:  # no...
    method _get_error_str (line 457) | def _get_error_str(self, e: list[int | float]) -> str:
    method _get_metric_index_mode_str (line 460) | def _get_metric_index_mode_str(self, candidate_id: int) -> str:
    method _process_x_in_xw (line 472) | def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_w_in_xw (line 484) | def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
    method _process_x_in_yx (line 494) | def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int) -> torc...
    method _process_y_in_yx (line 497) | def _process_y_in_yx(self, x: torch.Tensor, channels_dim: int) -> torc...
    method _process_xw_in_yx (line 500) | def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_yw_in_yx (line 503) | def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_wgts_centric_mod (line 506) | def _process_wgts_centric_mod(
    method _recover_mod (line 537) | def _recover_mod(self) -> None:

FILE: deepcompressor/calib/rotate.py
  class RMSNorm (line 24) | class RMSNorm(nn.Module):
    method __init__ (line 27) | def __init__(self, hidden_size: int, eps=1e-6) -> None:
    method forward (line 33) | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  class HadamardTransformHook (line 42) | class HadamardTransformHook(IOHook):
    method __init__ (line 43) | def __init__(
    method pre_forward (line 52) | def pre_forward(
  function rotate_in_channels (line 66) | def rotate_in_channels(weight: nn.Parameter, /, *, rotation: torch.Tenso...
  function rotate_out_channels (line 76) | def rotate_out_channels(weight: nn.Parameter, /, *, rotation: torch.Tens...
  function hadamard_in_channels (line 98) | def hadamard_in_channels(
  function get_rotation_matrix (line 120) | def get_rotation_matrix(num_channels: int, random: bool = True, compatib...
  function transform_rms_norm_and_linear (line 136) | def transform_rms_norm_and_linear(norm: nn.LayerNorm | RMSNorm, next_mod...
  function transform_layer_norm_to_rms_norm (line 165) | def transform_layer_norm_to_rms_norm(
  function transform_norm_and_linear (line 219) | def transform_norm_and_linear(

FILE: deepcompressor/calib/search.py
  function _reshape_w_for_wgts (line 26) | def _reshape_w_for_wgts(w: torch.Tensor, w_view_shape: torch.Size) -> to...
  function _reshape_x_for_wgts (line 35) | def _reshape_x_for_wgts(x: torch.Tensor, w_view_shape: torch.Size) -> to...
  function _reshape_x_for_ipts (line 45) | def _reshape_x_for_ipts(x: torch.Tensor, x_view_shape: torch.Size) -> to...
  function _reshape_w_for_ipts (line 55) | def _reshape_w_for_ipts(w: torch.Tensor, x_view_shape: torch.Size) -> to...
  class SearchBasedCalibrator (line 63) | class SearchBasedCalibrator(ABC, tp.Generic[_CONFIG, _CANDIDATE]):
    method __init__ (line 69) | def __init__(
    method population_size (line 129) | def population_size(self) -> int:
    method allows_x_quant_for_wgts (line 134) | def allows_x_quant_for_wgts(self) -> bool:
    method allows_w_quant_for_wgts (line 139) | def allows_w_quant_for_wgts(self) -> bool:
    method allows_x_quant_for_ipts (line 144) | def allows_x_quant_for_ipts(self) -> bool:
    method allows_w_quant_for_ipts (line 149) | def allows_w_quant_for_ipts(self) -> bool:
    method allows_x_quant_for_opts (line 154) | def allows_x_quant_for_opts(self) -> bool:
    method allows_y_quant_for_opts (line 159) | def allows_y_quant_for_opts(self) -> bool:
    method allows_w_quant_for_opts (line 164) | def allows_w_quant_for_opts(self) -> bool:
    method needs_to_pre_reshape_x_for_wgts (line 169) | def needs_to_pre_reshape_x_for_wgts(self) -> bool:
    method needs_to_pre_reshape_w_for_ipts (line 174) | def needs_to_pre_reshape_w_for_ipts(self) -> bool:
    method _reset (line 178) | def _reset(self, **kwargs) -> None:
    method reset (line 181) | def reset(self, **kwargs) -> None:
    method is_done (line 189) | def is_done(self) -> bool:
    method is_last_iter (line 193) | def is_last_iter(self) -> bool:
    method is_last_candidate_in_iter (line 197) | def is_last_candidate_in_iter(self) -> bool:
    method get_best (line 202) | def get_best(self) -> _CANDIDATE:
    method _ask (line 212) | def _ask(self) -> _CANDIDATE:
    method _tell (line 222) | def _tell(self, error: list[torch.Tensor]) -> None:
    method ask (line 231) | def ask(self) -> _CANDIDATE:
    method tell (line 241) | def tell(self, error: list[torch.Tensor]) -> None:
    method _parse_ipts (line 254) | def _parse_ipts(self, ipts: TensorsCache | None, set_device: bool = Fa...
    method _parse_args (line 287) | def _parse_args(  # noqa: C901
    method _reshape_w_for_wgts_centric_partial_products (line 435) | def _reshape_w_for_wgts_centric_partial_products(self, w: torch.Tensor...
    method _reshape_x_for_wgts_centric_partial_products (line 438) | def _reshape_x_for_wgts_centric_partial_products(
    method _reshape_w_for_ipts_centric_partial_products (line 443) | def _reshape_w_for_ipts_centric_partial_products(self, w: torch.Tensor...
    method _reshape_x_for_ipts_centric_partial_products (line 446) | def _reshape_x_for_ipts_centric_partial_products(
    method _reshape_w_for_full_products (line 451) | def _reshape_w_for_full_products(self, w: torch.Tensor, *, view_shape:...
    method _reshape_x_for_full_products (line 454) | def _reshape_x_for_full_products(
    method _process_x_in_xw (line 462) | def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_w_in_xw (line 465) | def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor: ...
    method _process_y_in_yx (line 468) | def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSI...
    method _process_x_in_yx (line 471) | def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_xw_in_yx (line 474) | def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor: ...
    method _process_yw_in_yx (line 477) | def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor: ...
    method _recover_mod (line 479) | def _recover_mod(self) -> None:
    method _process_wgts_centric_mod (line 487) | def _process_wgts_centric_mod(
    method _process_ipts_centric_mod (line 498) | def _process_ipts_centric_mod(
    method _process_opts_centric_mod (line 509) | def _process_opts_centric_mod(
    method calibrate (line 532) | def calibrate(
    method _calibrate_wgts (line 678) | def _calibrate_wgts(  # noqa: C901
    method _calibrate_ipts (line 839) | def _calibrate_ipts(  # noqa: C901
    method _calibrate_opts (line 996) | def _calibrate_opts(  # noqa: C901

FILE: deepcompressor/calib/smooth.py
  class ActivationSmoother (line 35) | class ActivationSmoother(BaseTensorProcessor):
    method is_enabled (line 47) | def is_enabled(self) -> bool:
    method get_input_packager (line 50) | def get_input_packager(self) -> BaseInputPackager | None:
    method get_output_packager (line 53) | def get_output_packager(self) -> BaseOutputPackager | None:
    method process (line 56) | def process(self, tensor: torch.Tensor) -> torch.Tensor:
  function get_smooth_span (line 82) | def get_smooth_span(
  function get_smooth_scale (line 117) | def get_smooth_scale(*, alpha_base: torch.Tensor, beta_base: torch.Tenso...
  class SmoothCalibrator (line 149) | class SmoothCalibrator(SearchBasedCalibrator[SmoothCalibConfig, torch.Te...
    method __init__ (line 152) | def __init__(
    method population_size (line 223) | def population_size(self) -> int:
    method allows_x_quant_for_wgts (line 228) | def allows_x_quant_for_wgts(self) -> bool:
    method allows_w_quant_for_wgts (line 233) | def allows_w_quant_for_wgts(self) -> bool:
    method allows_w_quant_for_ipts (line 238) | def allows_w_quant_for_ipts(self) -> bool:
    method allows_x_quant_for_opts (line 243) | def allows_x_quant_for_opts(self) -> bool:
    method allows_y_quant_for_opts (line 248) | def allows_y_quant_for_opts(self) -> bool:
    method allows_w_quant_for_opts (line 253) | def allows_w_quant_for_opts(self) -> bool:
    method span_mode_pairs (line 258) | def span_mode_pairs(self) -> list[tuple[SmoothSpanMode, SmoothSpanMode]]:
    method alpha_span_modes (line 263) | def alpha_span_modes(self) -> list[SmoothSpanMode]:
    method beta_span_modes (line 268) | def beta_span_modes(self) -> list[SmoothSpanMode]:
    method _reset (line 272) | def _reset(  # noqa: C901
    method _split_candidate_id (line 388) | def _split_candidate_id(self, candidate_id: int) -> tuple[int, int]:
    method get_best (line 403) | def get_best(self) -> torch.Tensor:
    method _ask (line 412) | def _ask(self) -> torch.Tensor:
    method _tell (line 428) | def _tell(self, error: list[torch.Tensor]) -> None:  # noqa: C901
    method _reshape_scale (line 495) | def _reshape_scale(
    method _process_x_in_xw (line 504) | def _process_x_in_xw(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_w_in_xw (line 519) | def _process_w_in_xw(self, w: torch.Tensor) -> torch.Tensor:
    method _process_x_in_yx (line 532) | def _process_x_in_yx(self, x: torch.Tensor, channels_dim: int | _MISSI...
    method _process_y_in_yx (line 553) | def _process_y_in_yx(self, y: torch.Tensor, channels_dim: int | _MISSI...
    method _process_xw_in_yx (line 574) | def _process_xw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_yw_in_yx (line 577) | def _process_yw_in_yx(self, w: torch.Tensor) -> torch.Tensor:
    method _process_wgts_centric_mod (line 580) | def _process_wgts_centric_mod(
    method _process_opts_centric_mod (line 627) | def _process_opts_centric_mod(
    method _update_best (line 656) | def _update_best(
  class SmoothLinearCalibrator (line 717) | class SmoothLinearCalibrator(SmoothCalibrator):
    method __init__ (line 720) | def __init__(
  class SmoothAttentionCalibrator (line 757) | class SmoothAttentionCalibrator(SmoothCalibrator):
    method __init__ (line 760) | def __init__(
    method calibrate (line 800) | def calibrate(
  function smooth_upscale_param (line 852) | def smooth_upscale_param(param: nn.Parameter, scale: torch.Tensor, chann...
  function smooth_downscale_param (line 872) | def smooth_downscale_param(param: nn.Parameter, scale: torch.Tensor, cha...
  function convert_smooth_upscale_to_downscale (line 894) | def convert_smooth_upscale_to_downscale(
  function smooth_linear_modules (line 920) | def smooth_linear_modules(
  function smooth_attention (line 1019) | def smooth_attention(

FILE: deepcompressor/csrc/pybind.cpp
  function PYBIND11_MODULE (line 6) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: deepcompressor/data/cache.py
  class ModuleForwardInput (line 18) | class ModuleForwardInput:
    method to (line 24) | def to(self, device: torch.device | str) -> "ModuleForwardInput":
    method update (line 40) | def update(self, x: dict[str | int, tp.Any] | None = None) -> "ModuleF...
  class TensorCache (line 62) | class TensorCache:
    method clear (line 92) | def clear(self):
    method get_factory_kwargs (line 97) | def get_factory_kwargs(self, **kwargs) -> dict[str, tp.Any]:
    method get_standardized_data (line 106) | def get_standardized_data(self, reshape: bool = False) -> list[torch.T...
    method repartition (line 122) | def repartition(self, max_batch_size: int, max_size: int, standardize:...
  class TensorsCache (line 181) | class TensorsCache:
    method __init__ (line 186) | def __init__(self, tensors: OrderedDict[str | int, TensorCache] | Tens...
    method num_tensors (line 191) | def num_tensors(self) -> int:
    method front (line 195) | def front(self) -> TensorCache:
    method items (line 199) | def items(self) -> tp.ItemsView[str | int, TensorCache]:
    method keys (line 203) | def keys(self) -> tp.KeysView[str | int]:
    method values (line 207) | def values(self) -> tp.ValuesView[TensorCache]:
    method __getitem__ (line 211) | def __getitem__(self, key: str | int) -> TensorCache:
    method __iter__ (line 215) | def __iter__(self) -> tp.Iterator[TensorCache]:
    method __len__ (line 219) | def __len__(self) -> int:
    method clear (line 223) | def clear(self):
    method set_num_samples (line 228) | def set_num_samples(self, num_samples: int):
    method extract (line 233) | def extract(self, index: int, kwargs: dict[str, tp.Any]) -> ModuleForw...
  class IOTensorsCache (line 257) | class IOTensorsCache:
    method __init__ (line 263) | def __init__(
    method clear (line 269) | def clear(self):
    method set_num_samples (line 276) | def set_num_samples(self, num_samples: int):

FILE: deepcompressor/data/codebook.py
  class Codebook (line 14) | class Codebook:
    method __post_init__ (line 33) | def __post_init__(self):
    method round (line 37) | def round(self, tensor: torch.Tensor) -> torch.Tensor:
    method to (line 52) | def to(self, *, device: torch.device | None = None, dtype: torch.dtype...
    method construct (line 75) | def construct(
    method build_for_float_point (line 112) | def build_for_float_point(
    method build_for_integer (line 173) | def build_for_integer(

FILE: deepcompressor/data/common.py
  class TensorType (line 9) | class TensorType(enum.Enum):

FILE: deepcompressor/data/dtype.py
  class QuantDataType (line 13) | class QuantDataType:
    method __init__ (line 18) | def __init__(
    method name (line 119) | def name(self) -> str:
    method codebook_name (line 124) | def codebook_name(self) -> str:
    method signed (line 129) | def signed(self) -> bool:
    method unsigned (line 134) | def unsigned(self) -> bool:
    method total_bits (line 139) | def total_bits(self) -> int:
    method exponent_bits (line 144) | def exponent_bits(self) -> int:
    method mantissa_bits (line 149) | def mantissa_bits(self) -> int:
    method has_subnormal (line 154) | def has_subnormal(self) -> bool:
    method has_inf (line 159) | def has_inf(self) -> bool:
    method has_nan (line 164) | def has_nan(self) -> bool:
    method magnitude (line 169) | def magnitude(self) -> bool:
    method is_float_point (line 174) | def is_float_point(self) -> bool:
    method is_integer (line 179) | def is_integer(self) -> bool:
    method is_exponent (line 184) | def is_exponent(self) -> bool:
    method exponent_mask (line 189) | def exponent_mask(self) -> int:
    method mantissa_mask (line 194) | def mantissa_mask(self) -> int:
    method _end_mantissa (line 199) | def _end_mantissa(self) -> int:
    method _end_exponent (line 203) | def _end_exponent(self) -> int:
    method exponent_bias (line 210) | def exponent_bias(self) -> int:
    method max_exponent_value (line 218) | def max_exponent_value(self) -> int:
    method min_exponent_value (line 226) | def min_exponent_value(self) -> int:
    method max_positive_normal_value (line 234) | def max_positive_normal_value(self) -> float:
    method min_positive_normal_value (line 246) | def min_positive_normal_value(self) -> float:
    method max_positive_subnormal (line 251) | def max_positive_subnormal(self) -> float:
    method min_positive_subnormal (line 261) | def min_positive_subnormal(self) -> float:
    method max_value (line 271) | def max_value(self) -> float:
    method min_value (line 276) | def min_value(self) -> float:
    method to_unsigned (line 290) | def to_unsigned(self) -> "QuantDataType":
    method get_codebook (line 299) | def get_codebook(self, *, device: torch.device | str = "cpu", dtype: t...
    method round (line 322) | def round(self, tensor: torch.Tensor) -> torch.Tensor:
    method from_str (line 339) | def from_str(cls, s: str, /) -> "QuantDataType":
    method _build_codebook (line 345) | def _build_codebook(self, *, device: torch.device | str = "cpu", dtype...
    method _build_default_name (line 362) | def _build_default_name(self) -> str:
    method _default_from_str (line 380) | def _default_from_str(s: str, /) -> "QuantDataType":
    method __str__ (line 423) | def __str__(self) -> str:
    method __repr__ (line 426) | def __repr__(self) -> str:
    method __eq__ (line 429) | def __eq__(self, value: object) -> bool:
    method __hash__ (line 434) | def __hash__(self) -> int:
  class _QDTypeMeta (line 438) | class _QDTypeMeta(type):
    method __getattr__ (line 439) | def __getattr__(cls, __name: str) -> tp.Any:
  class QDType (line 446) | class QDType(metaclass=_QDTypeMeta):

FILE: deepcompressor/data/range.py
  class RangeBound (line 17) | class RangeBound:
    method is_set (line 23) | def is_set(self) -> bool:
    method to_dict (line 27) | def to_dict(self) -> dict[str, tp.Any]:
    method from_dict (line 32) | def from_dict(cls, data: dict[str, tp.Any] | None) -> tp.Optional[tp.S...
  class QuantRange (line 37) | class QuantRange(RangeBound):
    method log2 (line 40) | def log2(self) -> "LogQuantRange":
    method intersect (line 48) | def intersect(self, quant_dtype: QuantDataType, *, has_zero_point: boo...
    method intersect_log2 (line 68) | def intersect_log2(self, quant_dtype: QuantDataType) -> "LogQuantRange":
    method construct (line 82) | def construct(
  class LogQuantRange (line 102) | class LogQuantRange(QuantRange):
    method log2 (line 105) | def log2(self) -> "LogQuantRange":
    method intersect (line 109) | def intersect(self, quant_dtype: QuantDataType, *, has_zero_point: boo...
    method intersect_log2 (line 124) | def intersect_log2(self, quant_dtype: QuantDataType) -> "LogQuantRange":
    method construct (line 144) | def construct(
  class ProtectiveQuantRange (line 162) | class ProtectiveQuantRange(QuantRange):
    method construct (line 168) | def construct(
  class DynamicRange (line 242) | class DynamicRange:
    method __post_init__ (line 249) | def __post_init__(self) -> None:
    method is_set (line 253) | def is_set(self) -> bool:
    method intersect (line 257) | def intersect(self, range_bound: RangeBound | None) -> "DynamicRange":
    method measure (line 277) | def measure(  # noqa: C901
    method scale (line 358) | def scale(
    method construct (line 398) | def construct(
    method _format_m_ (line 408) | def _format_m_(
    method to_dict (line 429) | def to_dict(self) -> dict[str, tp.Any]:
    method from_dict (line 434) | def from_dict(cls, data: dict[str, tp.Any] | None) -> tp.Optional[tp.S...

FILE: deepcompressor/data/scale.py
  class QuantScale (line 11) | class QuantScale:
    method __init__ (line 16) | def __init__(self):
    method num_children (line 20) | def num_children(self) -> int:
    method num_leaves (line 25) | def num_leaves(self) -> int:
    method is_quantized (line 29) | def is_quantized(self) -> bool:
    method get_child (line 33) | def get_child(self, index: int) -> "QuantScale":
    method append (line 37) | def append(self, scale: tp.Union[torch.Tensor, "QuantScale"]) -> "Quan...
    method extend (line 51) | def extend(self, scale: "QuantScale") -> "QuantScale":
    method join (line 62) | def join(self, scale: "QuantScale") -> "QuantScale":
    method remove_zero (line 66) | def remove_zero(self) -> "QuantScale":
    method state_dict (line 71) | def state_dict(
  function _join_scale_tensor (line 91) | def _join_scale_tensor(global_scale: torch.Tensor | None, local_scale: t...

FILE: deepcompressor/data/tensor.py
  class QuantTensor (line 11) | class QuantTensor:
    method __init__ (line 20) | def __init__(
    method data (line 39) | def data(self) -> torch.Tensor | None:
    method qdata (line 44) | def qdata(self) -> torch.Tensor | None:

FILE: deepcompressor/data/utils/dtype.py
  function infer_dtype_bits (line 11) | def infer_dtype_bits(dtype: torch.dtype | QuantDataType) -> int:
  function infer_dtype_name (line 43) | def infer_dtype_name(dtype: torch.dtype | QuantDataType) -> str:
  function eval_dtype (line 71) | def eval_dtype(  # noqa: C901

FILE: deepcompressor/data/utils/reshape.py
  class ReshapeFn (line 16) | class ReshapeFn:
    method __call__ (line 19) | def __call__(self, x: torch.Tensor, /, ic_last: bool = True) -> torch....
  class LinearReshapeFn (line 35) | class LinearReshapeFn(ReshapeFn):
    method __call__ (line 38) | def __call__(self, x: torch.Tensor, /, ic_last: bool = True) -> torch....
  class ConvInputReshapeFn (line 54) | class ConvInputReshapeFn(ReshapeFn):
    method __init__ (line 57) | def __init__(
    method __call__ (line 77) | def __call__(self, x: torch.Tensor, /, ic_last: bool = True) -> torch....
  class ConvOutputReshapedFn (line 104) | class ConvOutputReshapedFn(ReshapeFn):
    method __call__ (line 107) | def __call__(self, x: torch.Tensor, /, ic_last: bool = True) -> torch....
  class AttentionInputReshapeFn (line 128) | class AttentionInputReshapeFn(ReshapeFn):
    method __init__ (line 131) | def __init__(self, channels_dim: int) -> None:
    method __call__ (line 140) | def __call__(self, x: torch.Tensor, /, ic_last: bool = True) -> torch....

FILE: deepcompressor/data/utils/scale.py
  function infer_scale_dtypes (line 13) | def infer_scale_dtypes(
  function infer_scale_quant_spans (line 34) | def infer_scale_quant_spans(scale_dtypes: tp.Sequence[QuantDataType], ba...
  function infer_exponent_scale_level (line 42) | def infer_exponent_scale_level(scale_dtypes: tp.Sequence[torch.dtype | Q...

FILE: deepcompressor/data/utils/shape.py
  function infer_group_shape_name (line 14) | def infer_group_shape_name(group_shape: tp.Sequence[int]) -> str:
  function format_group_configs (line 58) | def format_group_configs(
  function infer_group_shapes (line 105) | def infer_group_shapes(group_shapes: tuple[tuple[int, ...], ...], shape:...
  function infer_view_shape (line 142) | def infer_view_shape(
  function infer_scale_view_shapes (line 173) | def infer_scale_view_shapes(
  function infer_shape (line 203) | def infer_shape(view_shape: torch.Size) -> torch.Size:
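The shape helpers above compute expanded "view" shapes for group-wise quantization. As a rough illustration of the idea (not the repo's actual `infer_view_shape`, which handles more cases), each dimension `d` with group size `g` is split into a `(d // g, g)` pair, with a non-positive group size treated as "the whole dimension":

```python
def view_shape(shape, group_shape):
    """Simplified sketch: expand each dim d with group size g into (d // g, g).

    A non-positive group size means "the whole dimension". Illustrative only;
    deepcompressor's infer_view_shape covers more edge cases.
    """
    view = []
    for d, g in zip(shape, group_shape):
        g = d if g <= 0 or g > d else g  # -1 / oversized group -> whole dim
        view += [d // g, g]
    return tuple(view)

# e.g. a (64, 128) weight with per-16-channel groups along dim 1:
# view_shape((64, 128), (-1, 16)) -> (1, 64, 8, 16)
```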

FILE: deepcompressor/data/zero.py
  class ZeroPointDomain (line 9) | class ZeroPointDomain(enum.Enum):

FILE: deepcompressor/dataset/action.py
  class CacheHook (line 16) | class CacheHook(IOHook):
    method __init__ (line 17) | def __init__(
    method pre_forward (line 47) | def pre_forward(
    method post_forward (line 60) | def post_forward(
  class CacheAction (line 75) | class CacheAction(ABC):
    method __init__ (line 80) | def __init__(self, device: torch.device | str | None = None) -> None:
    method apply (line 90) | def apply(
    method info (line 112) | def info(
    method get_input_packager (line 133) | def get_input_packager(self, name: str, module: nn.Module, cache: Tens...
    method get_output_packager (line 150) | def get_output_packager(self, name: str, module: nn.Module, cache: Ten...
    method register (line 167) | def register(
  class ConcatCacheAction (line 206) | class ConcatCacheAction(CacheAction):
    method apply (line 209) | def apply(
    method info (line 238) | def info(

FILE: deepcompressor/dataset/cache.py
  class BaseCalibCacheLoader (line 28) | class BaseCalibCacheLoader(ABC):
    method __init__ (line 34) | def __init__(self, dataset: torch.utils.data.Dataset, batch_size: int):
    method num_samples (line 47) | def num_samples(self) -> int:
    method iter_samples (line 52) | def iter_samples(self, *args, **kwargs) -> tp.Generator[ModuleForwardI...
    method _init_cache (line 56) | def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
    method _convert_layer_inputs (line 93) | def _convert_layer_inputs(
    method _convert_layer_outputs (line 115) | def _convert_layer_outputs(self, m: nn.Module, outputs: tp.Any) -> dic...
    method _layer_forward_pre_hook (line 133) | def _layer_forward_pre_hook(
    method _iter_layer_activations (line 151) | def _iter_layer_activations(  # noqa: C901
    method iter_layer_activations (line 430) | def iter_layer_activations(  # noqa: C901

FILE: deepcompressor/dataset/config.py
  class BaseDataLoaderConfig (line 17) | class BaseDataLoaderConfig(ABC):
    method generate_dirnames (line 33) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method build_dataset (line 48) | def build_dataset(self, *args, **kwargs) -> Dataset:
    method build_loader (line 53) | def build_loader(self, *args, **kwargs) -> DataLoader | BaseCalibCache...

FILE: deepcompressor/nn/patch/conv.py
  class ConcatConv2d (line 14) | class ConcatConv2d(nn.Module):
    method __init__ (line 15) | def __init__(
    method forward (line 53) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method from_conv2d (line 62) | def from_conv2d(conv: nn.Conv2d, splits: list[int]) -> "ConcatConv2d":
  class ShiftedConv2d (line 91) | class ShiftedConv2d(nn.Module):
    method __init__ (line 94) | def __init__(
    method forward (line 148) | def forward(self, input: torch.Tensor) -> torch.Tensor:
    method from_conv2d (line 155) | def from_conv2d(conv: nn.Conv2d, shift: float | torch.Tensor) -> "Shif...

FILE: deepcompressor/nn/patch/linear.py
  class ConcatLinear (line 10) | class ConcatLinear(nn.Module):
    method __init__ (line 11) | def __init__(
    method forward (line 38) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method from_linear (line 47) | def from_linear(linear: nn.Linear, splits: list[int]) -> "ConcatLinear":
  class ShiftedLinear (line 70) | class ShiftedLinear(nn.Module):
    method __init__ (line 73) | def __init__(
    method in_features (line 96) | def in_features(self) -> int:
    method out_features (line 100) | def out_features(self) -> int:
    method forward (line 103) | def forward(self, input: torch.Tensor) -> torch.Tensor:
    method from_linear (line 107) | def from_linear(linear: nn.Linear, shift: float | torch.Tensor) -> "Sh...

FILE: deepcompressor/nn/patch/lowrank.py
  class LowRankBranch (line 12) | class LowRankBranch(nn.Module):
    method __init__ (line 13) | def __init__(
    method reset_parameters (line 30) | def reset_parameters(self, weight: torch.Tensor | None = None) -> None:
    method get_effective_weight (line 61) | def get_effective_weight(self) -> torch.Tensor | None:
    method forward (line 69) | def forward(self, input: torch.Tensor) -> torch.Tensor | None:
    method as_hook (line 87) | def as_hook(

FILE: deepcompressor/nn/patch/sdpa.py
  class ScaleDotProductAttention (line 13) | class ScaleDotProductAttention(nn.Module):
    method forward (line 14) | def forward(

FILE: deepcompressor/nn/struct/attn.py
  class AttentionConfigStruct (line 28) | class AttentionConfigStruct:
    method head_size (line 60) | def head_size(self) -> int:
    method num_key_value_groups (line 65) | def num_key_value_groups(self) -> int:
    method num_channels (line 70) | def num_channels(self) -> int:
    method num_add_channels (line 75) | def num_add_channels(self) -> int:
    method num_query_channels (line 80) | def num_query_channels(self) -> int:
    method num_key_value_channels (line 85) | def num_key_value_channels(self) -> int:
    method num_head_channels (line 90) | def num_head_channels(self) -> int:
    method num_head_repeats (line 95) | def num_head_repeats(self) -> int:
  class FeedForwardConfigStruct (line 101) | class FeedForwardConfigStruct:
    method num_channels (line 125) | def num_channels(self) -> int:
    method num_intermediate_channels (line 130) | def num_intermediate_channels(self) -> int:
    method intermediate_lowerbound (line 135) | def intermediate_lowerbound(self) -> float | None:
    method infer_lowerbound (line 140) | def infer_lowerbound(act_type: str) -> float | None:
  class AttentionStruct (line 158) | class AttentionStruct(BaseModuleStruct):
    method qkv_proj (line 236) | def qkv_proj(self) -> list[nn.Linear]:
    method add_qkv_proj (line 240) | def add_qkv_proj(self) -> list[nn.Linear]:
    method out_proj (line 249) | def out_proj(self) -> nn.Linear:
    method add_out_proj (line 253) | def add_out_proj(self) -> nn.Linear:
    method qkv_proj_rnames (line 257) | def qkv_proj_rnames(self) -> list[str]:
    method add_qkv_proj_rnames (line 263) | def add_qkv_proj_rnames(self) -> list[str]:
    method out_proj_rname (line 272) | def out_proj_rname(self) -> str:
    method add_out_proj_rname (line 276) | def add_out_proj_rname(self) -> str:
    method qkv_proj_names (line 280) | def qkv_proj_names(self) -> list[str]:
    method add_qkv_proj_names (line 284) | def add_qkv_proj_names(self) -> list[str]:
    method out_proj_name (line 293) | def out_proj_name(self) -> str:
    method add_out_proj_name (line 297) | def add_out_proj_name(self) -> str:
    method __post_init__ (line 302) | def __post_init__(self) -> None:
    method is_self_attn (line 370) | def is_self_attn(self) -> bool:
    method is_cross_attn (line 373) | def is_cross_attn(self) -> bool:
    method is_joint_attn (line 376) | def is_joint_attn(self) -> bool:
    method filter_kwargs (line 379) | def filter_kwargs(self, kwargs: dict) -> dict:
    method named_key_modules (line 383) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 399) | def iter_attention_structs(self) -> tp.Generator[tp.Self, None, None]:
    method get_default_keys (line 403) | def get_default_keys(cls) -> list[str]:
  class SelfAttentionStruct (line 409) | class SelfAttentionStruct(AttentionStruct):
    method get_default_keys (line 430) | def get_default_keys(cls) -> list[str]:
  class CrossAttentionStruct (line 436) | class CrossAttentionStruct(AttentionStruct):
    method get_default_keys (line 457) | def get_default_keys(cls) -> list[str]:
  class JointAttentionStruct (line 463) | class JointAttentionStruct(AttentionStruct):
  class FeedForwardStruct (line 481) | class FeedForwardStruct(BaseModuleStruct):
    method __post_init__ (line 521) | def __post_init__(self) -> None:
    method named_key_modules (line 561) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
  class TransformerBlockStruct (line 577) | class TransformerBlockStruct(BaseModuleStruct):
    method __post_init__ (line 645) | def __post_init__(self) -> None:
    method named_key_modules (line 705) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 713) | def iter_attention_structs(self) -> tp.Generator[AttentionStruct, None...
    method iter_transformer_block_structs (line 717) | def iter_transformer_block_structs(self) -> tp.Generator[tp.Self, None...
  class BaseTransformerStruct (line 722) | class BaseTransformerStruct(BaseModuleStruct):
    method num_blocks (line 759) | def num_blocks(self) -> int:
    method block_structs (line 765) | def block_structs(self) -> list[TransformerBlockStruct]:
    method block_names (line 771) | def block_names(self) -> list[str]:
    method __post_init__ (line 775) | def __post_init__(self) -> None:
    method named_key_modules (line 787) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method iter_attention_structs (line 795) | def iter_attention_structs(self) -> tp.Generator[AttentionStruct, None...
    method iter_transformer_block_structs (line 799) | def iter_transformer_block_structs(self) -> tp.Generator[TransformerBl...

FILE: deepcompressor/nn/struct/base.py
  class BaseModuleStruct (line 17) | class BaseModuleStruct(ABC):
    method __post_init__ (line 42) | def __post_init__(self) -> None:
    method __call__ (line 60) | def __call__(self, *args: tp.Any, **kwds: tp.Any) -> tp.Any:
    method named_key_modules (line 64) | def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Modu...
    method get_default_keys (line 69) | def get_default_keys(cls) -> list[str]:
    method register_factory (line 74) | def register_factory(
    method construct (line 116) | def construct(

FILE: deepcompressor/quantizer/config/base.py
  class BaseQuantizerConfig (line 25) | class BaseQuantizerConfig(EnableConfig):
    method quant_dtype (line 30) | def quant_dtype(self) -> QuantDataType | None:
    method zero_domain (line 36) | def zero_domain(self) -> ZeroPointDomain | None:
    method largest_group_shape (line 42) | def largest_group_shape(self) -> tp.Sequence[int]:
    method smallest_group_shape (line 48) | def smallest_group_shape(self) -> tp.Sequence[int]:
    method is_enabled (line 52) | def is_enabled(self) -> bool:
    method decompose (line 57) | def decompose(self) -> "DecomposedQuantizerConfig":
    method generate_dirnames (line 61) | def generate_dirnames(
  class DecomposedQuantizerConfig (line 90) | class DecomposedQuantizerConfig(BaseQuantizerConfig):
    method quant_dtype (line 95) | def quant_dtype(self) -> QuantDataType | None:
    method zero_domain (line 99) | def zero_domain(self) -> ZeroPointDomain | None:
    method largest_group_shape (line 103) | def largest_group_shape(self) -> tp.Sequence[int]:
    method smallest_group_shape (line 107) | def smallest_group_shape(self) -> tp.Sequence[int]:
    method num_steps (line 111) | def num_steps(self) -> int:
    method decompose (line 114) | def decompose(self) -> "DecomposedQuantizerConfig":
    method __eq__ (line 117) | def __eq__(self, value: object) -> bool:
    method _get_effective_bits (line 135) | def _get_effective_bits(
    method _get_dtype_name (line 165) | def _get_dtype_name(self, default_dtype: torch.dtype = torch.float16) ...
    method _get_group_shapes_name (line 185) | def _get_group_shapes_name(self, default_dtype: torch.dtype = torch.fl...
    method generate_dirnames (line 212) | def generate_dirnames(
  class QuantizerConfig (line 249) | class QuantizerConfig(BaseQuantizerConfig):
    method __post_init__ (line 273) | def __post_init__(self) -> None:
    method quant_dtype (line 281) | def quant_dtype(self) -> QuantDataType | None:
    method zero_domain (line 286) | def zero_domain(self) -> ZeroPointDomain | None:
    method largest_group_shape (line 291) | def largest_group_shape(self) -> tp.Sequence[int]:
    method smallest_group_shape (line 296) | def smallest_group_shape(self) -> tp.Sequence[int]:
    method decompose (line 300) | def decompose(self) -> DecomposedQuantizerConfig:
  class ProgressiveQuantizerConfig (line 307) | class ProgressiveQuantizerConfig(QuantizerConfig):
    method __post_init__ (line 335) | def __post_init__(self) -> None:
    method decompose (line 357) | def decompose(self) -> DecomposedQuantizerConfig:

FILE: deepcompressor/quantizer/config/kernel.py
  class BaseQuantKernel (line 18) | class BaseQuantKernel(ABC):
    method quantize (line 22) | def quantize(
  class BaseQuantKernelConfig (line 60) | class BaseQuantKernelConfig(ABC):
    method name (line 65) | def name(self) -> str:
    method build (line 70) | def build(self) -> BaseQuantKernel:
    method generate_dirnames (line 75) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class BaseKeyEnableQuantKernelConfig (line 91) | class BaseKeyEnableQuantKernelConfig(KeyEnableConfig, EnableConfig):
    method __post_init__ (line 99) | def __post_init__(self) -> None:
    method is_enabled (line 102) | def is_enabled(self) -> bool:
    method is_enabled_for (line 105) | def is_enabled_for(self, key: str) -> bool:
    method specialize_for (line 108) | def specialize_for(self, key: str) -> BaseQuantKernelConfig | None:
    method generate_dirnames (line 121) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method organize (line 140) | def organize(self) -> None:

FILE: deepcompressor/quantizer/config/lowrank.py
  class QuantLowRankConfig (line 15) | class QuantLowRankConfig(EnableConfig):
    method is_enabled (line 31) | def is_enabled(self) -> bool:
    method generate_dirnames (line 34) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:

FILE: deepcompressor/quantizer/impl/base.py
  class QuantizerImpl (line 23) | class QuantizerImpl:
    method is_enabled (line 41) | def is_enabled(self) -> bool:
    method quantize (line 49) | def quantize(
    method _quantize (line 132) | def _quantize(  # noqa: C901
    method update (line 289) | def update(

FILE: deepcompressor/quantizer/impl/info.py
  class QuantStepInfo (line 19) | class QuantStepInfo:
    method __post_init__ (line 39) | def __post_init__(self):
    method tensor_zero_domain (line 52) | def tensor_zero_domain(self) -> ZeroPointDomain | None:
    method tensor_quant_range (line 56) | def tensor_quant_range(self) -> QuantRange:
    method tensor_range_bound (line 61) | def tensor_range_bound(self) -> RangeBound | None:
    method to_config (line 64) | def to_config(self) -> QuantizerConfig:
    method construct (line 73) | def construct(
  class QuantInfo (line 97) | class QuantInfo:
    method num_steps (line 102) | def num_steps(self) -> int:
    method get_child (line 105) | def get_child(self, idx: int) -> QuantStepInfo:
    method is_outdated (line 108) | def is_outdated(
    method construct (line 141) | def construct(

FILE: deepcompressor/quantizer/impl/scale.py
  function quantize_scale (line 20) | def quantize_scale(
  class QuantScaleInfo (line 68) | class QuantScaleInfo:
    method has_zero_point (line 95) | def has_zero_point(self) -> bool:
    method __post_init__ (line 98) | def __post_init__(self):
    method quantize (line 142) | def quantize(

FILE: deepcompressor/quantizer/impl/simple.py
  function simple_quantize (line 13) | def simple_quantize(

FILE: deepcompressor/quantizer/impl/ste.py
  class STEFunction (line 11) | class STEFunction(torch.autograd.Function):
    method forward (line 15) | def forward(ctx: tp.Any, tensor: torch.Tensor, fn: tp.Callable[[torch....
    method backward (line 20) | def backward(ctx: tp.Any, grad_output: torch.Tensor) -> tp.Tuple[torch...
  function ste (line 25) | def ste(tensor: torch.Tensor, fn: tp.Callable[[torch.Tensor], torch.Tens...
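`STEFunction`/`ste` implement the straight-through estimator: a non-differentiable op (e.g. rounding) is applied in the forward pass, while the backward pass treats it as the identity. A toy forward-mode sketch of that gradient rule, using a hypothetical `Dual` value-plus-gradient pair rather than the repo's `torch.autograd.Function`:

```python
class Dual:
    """Toy value-plus-gradient pair (stand-in for an autograd tensor)."""
    def __init__(self, val, grad=1.0):
        self.val, self.grad = val, grad

def ste(x, fn):
    # Straight-through estimator: apply the non-differentiable fn
    # (e.g. rounding) to the value, but pass the gradient through
    # unchanged, as if fn were the identity.
    return Dual(fn(x.val), x.grad)

y = ste(Dual(2.7), round)  # y.val == 3, y.grad == 1.0
```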

FILE: deepcompressor/quantizer/kernel/gptq.py
  class QuantGptqConfig (line 25) | class QuantGptqConfig(BaseQuantKernelConfig):
    method name (line 45) | def name(self) -> str:
    method build (line 48) | def build(self) -> "QuantGptqKernel":
    method generate_dirnames (line 51) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
  class QuantGptqKernel (line 66) | class QuantGptqKernel(BaseQuantKernel):
    method __init__ (line 67) | def __init__(self, config: "QuantGptqConfig"):
    method quantize (line 70) | def quantize(
  function gptq_quantize (line 129) | def gptq_quantize(  # noqa: C901

FILE: deepcompressor/quantizer/kernel/rtn.py
  class QuantRtnKernel (line 15) | class QuantRtnKernel(BaseQuantKernel):
    method quantize (line 18) | def quantize(
  function rtn_quantize (line 68) | def rtn_quantize(
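Round-to-nearest (RTN) is the baseline kernel: scale values onto an integer grid, round, clamp, and rescale. A minimal symmetric-integer sketch (the repo's `rtn_quantize` operates on tensors with group-wise scales; this list-based version only shows the arithmetic):

```python
def rtn_quantize(values, n_bits=4):
    """Symmetric round-to-nearest: scale by max |x|, round, clamp, rescale."""
    qmax = 2 ** (n_bits - 1) - 1                 # e.g. 7 for 4-bit signed
    amax = max(abs(v) for v in values) or 1.0    # avoid divide-by-zero
    scale = amax / qmax
    q = [max(-qmax - 1, min(qmax, round(v / scale))) for v in values]
    return [qi * scale for qi in q], scale

deq, scale = rtn_quantize([0.1, -0.52, 0.9])
```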

FILE: deepcompressor/quantizer/processor.py
  class Quantizer (line 24) | class Quantizer(QuantizerImpl, BaseTensorProcessor):
    method is_enabled_low_rank (line 77) | def is_enabled_low_rank(self) -> bool:
    method get_input_packager (line 84) | def get_input_packager(self) -> BaseInputPackager | None:
    method get_output_packager (line 87) | def get_output_packager(self) -> BaseOutputPackager | None:
    method process (line 90) | def process(self, tensor: torch.Tensor) -> torch.Tensor:
    method quantize (line 93) | def quantize(
    method update (line 185) | def update(
    method quantize_with_low_rank (line 215) | def quantize_with_low_rank(
    method state_dict (line 312) | def state_dict(self, device: torch.device | str = "cpu") -> dict[str, ...
    method load_state_dict (line 341) | def load_state_dict(self, state_dict: dict[str, tp.Any], device: torch...

FILE: deepcompressor/utils/common.py
  function join_name (line 22) | def join_name(prefix: str, name: str, sep: str = ".", relative: bool = T...
  function join_names (line 56) | def join_names(*names: str, sep: str = ".", relative: bool = True) -> str:
  function num2str (line 76) | def num2str(num: int | float) -> str:
  function split_sequence (line 93) | def split_sequence(lst: tp.Sequence[tp.Any], splits: tp.Sequence[int]) -...
  function tree_map (line 115) | def tree_map(func: tp.Callable[[tp.Any], tp.Any], tree: tp.Any) -> tp.Any:
  function tree_copy_with_ref (line 127) | def tree_copy_with_ref(
  function tree_split (line 148) | def tree_split(tree: tp.Any) -> list[tp.Any]:
  function tree_collate (line 187) | def tree_collate(batch: list[tp.Any] | tuple[tp.Any, ...]) -> tp.Any:
  function hash_str_to_int (line 203) | def hash_str_to_int(s: str) -> int:
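Among the tree utilities above, `tree_map` applies a function over nested containers. A minimal sketch of that pattern (the repo's version may recognize additional node types):

```python
def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # rebuild with the same container type
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)  # leaf

tree_map(lambda x: x * 2, {"a": [1, 2], "b": (3,)})
# -> {"a": [2, 4], "b": (6,)}
```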

FILE: deepcompressor/utils/config/base.py
  class EnableConfig (line 13) | class EnableConfig(ABC):
    method is_enabled (line 15) | def is_enabled(self) -> bool:
  class KeyEnableConfig (line 20) | class KeyEnableConfig(ABC):
    method is_enabled_for (line 22) | def is_enabled_for(self, key: str) -> bool:
  class SkipBasedConfig (line 29) | class SkipBasedConfig(KeyEnableConfig, EnableConfig):
    method __post_init__ (line 39) | def __post_init__(self) -> None:
    method is_enabled (line 45) | def is_enabled(self) -> bool:
    method is_enabled_for (line 49) | def is_enabled_for(self, key: str) -> bool:
    method generate_dirnames (line 62) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method update_get_arguments (line 79) | def update_get_arguments(
    method update_from_dict (line 103) | def update_from_dict(
  class IncludeBasedConfig (line 118) | class IncludeBasedConfig(KeyEnableConfig, EnableConfig):
    method __post_init__ (line 128) | def __post_init__(self) -> None:
    method is_enabled (line 134) | def is_enabled(self) -> bool:
    method is_enabled_for (line 138) | def is_enabled_for(self, key: str) -> bool:
    method generate_dirnames (line 151) | def generate_dirnames(self, *, prefix: str = "", **kwargs) -> list[str]:
    method update_get_arguments (line 171) | def update_get_arguments(
    method update_from_dict (line 195) | def update_from_dict(

FILE: deepcompressor/utils/config/model.py
  class BaseModelConfig (line 16) | class BaseModelConfig(ABC):
    method __post_init__ (line 41) | def __post_init__(self):
    method build (line 53) | def build(self, *args, **kwargs) -> tp.Any:

FILE: deepcompressor/utils/config/output.py
  class OutputConfig (line 15) | class OutputConfig:
    method __post_init__ (line 39) | def __post_init__(self):
    method running_dirpath (line 44) | def running_dirpath(self) -> str:
    method error_dirpath (line 49) | def error_dirpath(self) -> str:
    method job_dirname (line 54) | def job_dirname(self) -> str:
    method job_dirpath (line 59) | def job_dirpath(self) -> str:
    method running_job_dirname (line 64) | def running_job_dirname(self) -> str:
    method error_job_dirname (line 69) | def error_job_dirname(self) -> str:
    method running_job_dirpath (line 74) | def running_job_dirpath(self) -> str:
    method lock (line 78) | def lock(self) -> None:
    method unlock (line 89) | def unlock(self, error: bool = False) -> None:
    method is_locked_by_others (line 96) | def is_locked_by_others(self) -> bool:
    method get_running_path (line 104) | def get_running_path(self, filename: str) -> str:
    method get_running_job_path (line 109) | def get_running_job_path(self, filename: str) -> str:
    method generate_timestamp (line 115) | def generate_timestamp() -> str:

FILE: deepcompressor/utils/config/path.py
  class BasePathConfig (line 12) | class BasePathConfig:
    method is_all_set (line 15) | def is_all_set(self) -> bool:
    method is_all_empty (line 28) | def is_all_empty(self) -> bool:
    method clone (line 41) | def clone(self) -> tp.Self:
    method add_parent_dirs (line 51) | def add_parent_dirs(self, *parent_dirs: str) -> tp.Self:
    method add_children (line 65) | def add_children(self, *children: str) -> tp.Self:
    method to_dirpath (line 79) | def to_dirpath(self) -> tp.Self:
    method apply (line 88) | def apply(self, fn: tp.Callable) -> tp.Self:

FILE: deepcompressor/utils/dataclass.py
  function get_fields (line 9) | def get_fields(class_or_instance, *, init_vars: bool = False, class_vars...

FILE: deepcompressor/utils/hooks/branch.py
  class AccumBranchHook (line 15) | class AccumBranchHook(IOHook):
    method __init__ (line 18) | def __init__(
    method pre_forward (line 28) | def pre_forward(
    method post_forward (line 43) | def post_forward(

FILE: deepcompressor/utils/hooks/hook.py
  class Hook (line 17) | class Hook:
    method __init__ (line 25) | def __init__(self, *, pre: bool, post: bool) -> None:
    method is_in_hook (line 44) | def is_in_hook(self) -> bool:
    method is_out_hook (line 48) | def is_out_hook(self) -> bool:
    method is_inout_hook (line 52) | def is_inout_hook(self) -> bool:
    method activate (line 56) | def activate(self) -> tp.Self:
    method deactivate (line 61) | def deactivate(self) -> tp.Self:
    method pre_forward (line 66) | def pre_forward(
    method post_forward (line 81) | def post_forward(
    method __call__ (line 102) | def __call__(self, *args, **kwargs) -> tp.Any:
    method register (line 113) | def register(
    method remove (line 143) | def remove(self, module: nn.Module | tp.Iterable[nn.Module] | None = N...
  class EarlyStopException (line 167) | class EarlyStopException(Exception):
  class EarlyStopHook (line 173) | class EarlyStopHook(Hook):
    method __init__ (line 174) | def __init__(self):
    method pre_forward (line 177) | def pre_forward(self, *args, **kwargs) -> None:
  class IOHook (line 181) | class IOHook(Hook):
    method __init__ (line 189) | def __init__(

FILE: deepcompressor/utils/hooks/packager.py
  class BaseInputPackager (line 24) | class BaseInputPackager(ABC):
    method unpack (line 28) | def unpack(
    method repack (line 48) | def repack(
  class SimpleInputPackager (line 74) | class SimpleInputPackager(BaseInputPackager):
    method unpack (line 75) | def unpack(
    method repack (line 80) | def repack(
  class KeyedInputPackager (line 90) | class KeyedInputPackager(BaseInputPackager):
    method __init__ (line 91) | def __init__(self, module: nn.Module, index_or_keys: list[int | str]):
    method unpack (line 127) | def unpack(
    method repack (line 139) | def repack(
  class BaseOutputPackager (line 156) | class BaseOutputPackager(ABC):
    method unpack (line 160) | def unpack(
    method repack (line 186) | def repack(
  class SimpleOutputPackager (line 215) | class SimpleOutputPackager(BaseOutputPackager):
    method unpack (line 216) | def unpack(
    method repack (line 227) | def repack(
  class KeyedOutputPackager (line 241) | class KeyedOutputPackager(BaseOutputPackager):
    method __init__ (line 242) | def __init__(self, index_or_keys: list[int | str]):
    method unpack (line 245) | def unpack(
    method repack (line 268) | def repack(

FILE: deepcompressor/utils/hooks/processor.py
  class BaseTensorProcessor (line 18) | class BaseTensorProcessor(abc.ABC):
    method is_enabled (line 20) | def is_enabled(self) -> bool: ...
    method get_input_packager (line 23) | def get_input_packager(self) -> BaseInputPackager | None: ...
    method get_output_packager (line 26) | def get_output_packager(self) -> BaseOutputPackager | None: ...
    method process (line 29) | def process(self, tensor: torch.Tensor) -> torch.Tensor: ...
    method as_hook (line 31) | def as_hook(
  class ProcessHook (line 49) | class ProcessHook(IOHook):
    method __init__ (line 50) | def __init__(
    method process (line 65) | def process(self, tensors: dict[int | str, torch.Tensor]) -> dict[int ...
    method pre_forward (line 74) | def pre_forward(
    method post_forward (line 83) | def post_forward(

FILE: deepcompressor/utils/math/functional.py
  function is_pow2 (line 9) | def is_pow2(n: int) -> bool:
  function root_ (line 23) | def root_(y: torch.Tensor, index: float) -> torch.Tensor:
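`is_pow2` is the standard bit-trick power-of-two test, which these math helpers use for shape checks. The usual implementation:

```python
def is_pow2(n: int) -> bool:
    # A positive power of two has exactly one set bit, so clearing the
    # lowest set bit via n & (n - 1) must leave zero.
    return n > 0 and (n & (n - 1)) == 0
```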

FILE: deepcompressor/utils/math/hadamard.py
  function _matmul_hadU (line 17) | def _matmul_hadU(X: torch.Tensor, hadamard_K: torch.Tensor | None, K: in...
  function random_hadamard_matrix (line 38) | def random_hadamard_matrix(size: int) -> torch.Tensor:
  function hardmard_transform (line 48) | def hardmard_transform(
  class HadamardMatrix (line 65) | class HadamardMatrix:
    method get (line 72) | def get(
    method get_lhs (line 89) | def get_lhs(n: int) -> tuple[torch.FloatTensor, int]:
    method _get_hadamard_k (line 99) | def _get_hadamard_k(k: int) -> torch.FloatTensor:
    method _get_hadamard_12 (line 106) | def _get_hadamard_12() -> torch.FloatTensor:
    method _get_hadamard_40 (line 125) | def _get_hadamard_40() -> torch.FloatTensor:
    method _get_hadamard_20 (line 1812) | def _get_hadamard_20() -> torch.FloatTensor:
    method _get_hadamard_28 (line 1839) | def _get_hadamard_28() -> torch.FloatTensor:
    method _get_hadamard_36 (line 2686) | def _get_hadamard_36() -> torch.FloatTensor:
    method _get_hadamard_60 (line 4061) | def _get_hadamard_60() -> torch.FloatTensor:
    method _get_hadamard_52 (line 7788) | def _get_hadamard_52() -> torch.FloatTensor:
    method _get_hadamard_108 (line 10603) | def _get_hadamard_108() -> torch.FloatTensor:
    method _get_hadamard_140 (line 22490) | def _get_hadamard_140() -> torch.FloatTensor:
    method _get_hadamard_156 (line 42377) | def _get_hadamard_156() -> torch.FloatTensor:
    method _get_hadamard_172 (line 67032) | def _get_hadamard_172() -> torch.FloatTensor:
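The many `_get_hadamard_K` methods above hard-code Hadamard base matrices for non-power-of-two sizes; power-of-two sizes can be built recursively via the Sylvester construction, which the transform composes with those bases. A self-contained sketch of the Sylvester step:

```python
def sylvester_hadamard(n):
    """Hadamard matrix for power-of-two n via H_{2k} = [[H, H], [H, -H]]."""
    H = [[1]]
    while len(H) < n:
        # double the matrix: top half [H, H], bottom half [H, -H]
        H = [row + row for row in H] + [row + [-v for v in row] for row in H]
    return H

H = sylvester_hadamard(4)
# rows are mutually orthogonal: H @ H^T == 4 * I
```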

FILE: deepcompressor/utils/patch.py
  function copy_func (line 14) | def copy_func(f: types.FunctionType, globals: dict[str, typing.Any] | No...
  function get_module_parents_map (line 39) | def get_module_parents_map(

FILE: deepcompressor/utils/tools/logging.py
  function getLogger (line 48) | def getLogger(name: str | None = None) -> logging.Logger:
  function log (line 60) | def log(level: int, msg: str, logger: logging.Logger | None = None) -> N...
  function info (line 81) | def info(msg: str, logger: logging.Logger | None = None):
  function debug (line 92) | def debug(msg: str, logger: logging.Logger | None = None):
  function warning (line 103) | def warning(msg: str, logger: logging.Logger | None = None):
  function error (line 114) | def error(msg: str, logger: logging.Logger | None = None):
  function critical (line 125) | def critical(msg: str, logger: logging.Logger | None = None):
  class Formatter (line 136) | class Formatter(logging.Formatter):
    method __init__ (line 141) | def __init__(self, fmt: str | None = None, datefmt: str | None = None,...
    method format (line 151) | def format(self, record: logging.LogRecord) -> str:
    method indent_inc (line 178) | def indent_inc(delta: int = 2):
    method indent_dec (line 183) | def indent_dec(delta: int = 2):
    method indent_reset (line 188) | def indent_reset(indent: int = 0):
  function basicConfig (line 193) | def basicConfig(**kwargs) -> None:
  function setup (line 203) | def setup(
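The `Formatter` class exposes class-level indentation controls (`indent_inc`, `indent_dec`, `indent_reset`), suggesting that `format` prefixes every emitted line with a shared indent so nested phases of a run log as an indented tree. A minimal sketch of that mechanism (a guess at the behavior, not the repo's exact implementation):

```python
import logging


class IndentFormatter(logging.Formatter):
    """Formatter that prefixes each log line with a process-wide indent."""

    _indent: int = 0

    def format(self, record: logging.LogRecord) -> str:
        msg = super().format(record)
        pad = " " * IndentFormatter._indent
        return "\n".join(pad + line for line in msg.splitlines())

    @classmethod
    def indent_inc(cls, delta: int = 2) -> None:
        cls._indent += delta

    @classmethod
    def indent_dec(cls, delta: int = 2) -> None:
        cls._indent = max(0, cls._indent - delta)

    @classmethod
    def indent_reset(cls, indent: int = 0) -> None:
        cls._indent = indent
```

Because the indent lives on the class, every handler using this formatter shifts together when a calibration stage calls `indent_inc()` on entry and `indent_dec()` on exit.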

FILE: deepcompressor/utils/tools/sys.py
  function _get_visible_gpu_capacity_list (line 10) | def _get_visible_gpu_capacity_list() -> list[int]:
  function _get_ram_capacity (line 19) | def _get_ram_capacity() -> int:
  function get_max_memory_map (line 28) | def get_max_memory_map(ratio: float = 0.9) -> dict[str, str]:
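`get_max_memory_map` presumably combines the GPU-capacity and RAM probes above into a `max_memory`-style map (as accepted by Hugging Face `accelerate`), scaled down by `ratio` to leave headroom. A self-contained sketch that takes the capacities as explicit arguments, since the real helpers likely query `torch.cuda` and system RAM; names and key format follow the `dict[str, str]` signature shown but are otherwise assumptions:

```python
def build_max_memory_map(gpu_capacities_gib: list[int],
                         ram_gib: int,
                         ratio: float = 0.9) -> dict[str, str]:
    """Map each visible GPU index (and "cpu") to a budget string like "72GiB",
    reserving (1 - ratio) of each device's capacity as headroom."""
    memory_map = {
        str(i): f"{int(capacity * ratio)}GiB"
        for i, capacity in enumerate(gpu_capacities_gib)
    }
    memory_map["cpu"] = f"{int(ram_gib * ratio)}GiB"
    return memory_map
```

Such a map is typically passed to `accelerate`'s dispatch utilities to cap per-device placement during model loading.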
Condensed preview — 229 files, each showing path, character count, and a content snippet (full structured content: 3,996K chars).
[
  {
    "path": ".gitignore",
    "chars": 3182,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11401,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 18617,
    "preview": "<p align=\"center\">\n<img src=\"assets/deepcompressor.png\" alt=\"DeepCompressor Logo\" width=\"450\">\n</p>\n\n<h2><p align=\"cente"
  },
  {
    "path": "assets/diffusion/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "assets/llm/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/__init__.py",
    "chars": 47,
    "preview": "from .version import __version__  # noqa: F401\n"
  },
  {
    "path": "deepcompressor/app/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/app/diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/app/diffusion/cache/__init__.py",
    "chars": 71,
    "preview": "from .config import DiffusionPtqCacheConfig, DiffusionQuantCacheConfig\n"
  },
  {
    "path": "deepcompressor/app/diffusion/cache/config.py",
    "chars": 2403,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM quantization cache configuration.\"\"\"\n\nimport functools\nimport re\nimport typing as tp\nfrom"
  },
  {
    "path": "deepcompressor/app/diffusion/config.py",
    "chars": 8440,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Top-level config of post-training quantization for a diffusion model.\"\"\"\n\nimport os\nfrom data"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/__init__.py",
    "chars": 138,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .base import DiffusionDataset\nfrom .calib import DiffusionCalibCacheLoader, DiffusionCalib"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/base.py",
    "chars": 2623,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Dataset for diffusion models.\"\"\"\n\nimport os\nimport random\nimport typing as tp\n\nimport numpy a"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/calib.py",
    "chars": 15238,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Calibration dataset for diffusion models.\"\"\"\n\nimport random\nimport typing as tp\nfrom collecti"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/collect/calib.py",
    "chars": 5657,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Collect calibration dataset.\"\"\"\n\nimport os\nfrom dataclasses import dataclass\n\nimport datasets"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/collect/utils.py",
    "chars": 3293,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Common utilities for collecting data.\"\"\"\n\nimport inspect\nimport typing as tp\n\nimport torch\nim"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/COCO/COCO.py",
    "chars": 7724,
    "preview": "# coding=utf-8\n# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.\n#\n# License"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/COCO/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/DCI/DCI.py",
    "chars": 4032,
    "preview": "import os\nimport random\n\nimport datasets\nimport yaml\nfrom PIL import Image\n\n_CITATION = \"\"\"\\\n@InProceedings{Urbanek_2024"
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/DCI/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/MJHQ/MJHQ.py",
    "chars": 3956,
    "preview": "import json\nimport os\nimport random\n\nimport datasets\nfrom PIL import Image\n\n_CITATION = \"\"\"\\\n@misc{li2024playground,\n   "
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/MJHQ/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/__init__.py",
    "chars": 2624,
    "preview": "import os\nimport random\n\nimport datasets\nimport yaml\n\n__all__ = [\"get_dataset\"]\n\n\ndef load_dataset_yaml(meta_path: str, "
  },
  {
    "path": "deepcompressor/app/diffusion/dataset/data/dump.py",
    "chars": 3905,
    "preview": "import argparse\nimport os\n\nimport yaml\nfrom tqdm import tqdm\n\nfrom ...utils import get_control\nfrom . import get_dataset"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/__init__.py",
    "chars": 65,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .config import DiffusionEvalConfig\n"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/config.py",
    "chars": 9873,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Diffusion model evaluation.\"\"\"\n\nimport logging\nimport os\nimport typing as tp\nfrom dataclasses"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/__init__.py",
    "chars": 4543,
    "preview": "import logging\nimport os\n\nfrom deepcompressor.app.diffusion.dataset.data import get_dataset\n\nfrom .fid import compute_fi"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/fid.py",
    "chars": 4298,
    "preview": "import os\nfrom datetime import datetime\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom cleanfid import fid\nfro"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/image_reward.py",
    "chars": 889,
    "preview": "import os\n\nimport datasets\nimport torch\nfrom tqdm import tqdm\n\n__all__ = [\"compute_image_reward\"]\n\n\ndef compute_image_re"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/multimodal.py",
    "chars": 2573,
    "preview": "import os\n\nimport datasets\nimport numpy as np\nimport torch\nimport torchmetrics\nimport torchvision\nfrom PIL import Image\n"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/run.py",
    "chars": 940,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Evaluate generated images or videos using the specified metrics.\"\"\"\n\nimport json\nimport os\n\nf"
  },
  {
    "path": "deepcompressor/app/diffusion/eval/metrics/similarity.py",
    "chars": 3943,
    "preview": "import os\n\nimport datasets\nimport torch\nimport torchmetrics\nimport torchvision\nfrom PIL import Image\nfrom torch.utils im"
  },
  {
    "path": "deepcompressor/app/diffusion/nn/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/app/diffusion/nn/attention.py",
    "chars": 7213,
    "preview": "# -*- coding: utf-8 -*-\n\nimport typing as tp\n\nimport diffusers\nimport packaging.version\nimport torch\nimport torch.nn as "
  },
  {
    "path": "deepcompressor/app/diffusion/nn/patch.py",
    "chars": 6069,
    "preview": "import torch.nn as nn\nfrom diffusers.models.attention_processor import Attention\nfrom diffusers.models.transformers.tran"
  },
  {
    "path": "deepcompressor/app/diffusion/nn/struct.py",
    "chars": 87134,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for Diffusion Models.\"\"\"\n\nimport enum\nimport typing as tp\nfrom abc import a"
  },
  {
    "path": "deepcompressor/app/diffusion/pipeline/__init__.py",
    "chars": 69,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .config import DiffusionPipelineConfig\n"
  },
  {
    "path": "deepcompressor/app/diffusion/pipeline/config.py",
    "chars": 18410,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Diffusion pipeline configuration module.\"\"\"\n\nimport gc\nimport typing as tp\nfrom dataclasses i"
  },
  {
    "path": "deepcompressor/app/diffusion/ptq.py",
    "chars": 17924,
    "preview": "import gc\nimport json\nimport os\nimport pprint\nimport traceback\n\nimport torch\nfrom diffusers import DiffusionPipeline\n\nfr"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/__init__.py",
    "chars": 382,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .activation import quantize_diffusion_activations\nfrom .config import DiffusionQuantCacheC"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/activation.py",
    "chars": 13170,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Diffusion model activation quantization calibration module.\"\"\"\n\nimport gc\nimport typing as tp"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/config.py",
    "chars": 24884,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization config.\"\"\"\n\nimport os\nfrom dataclasses import dataclass, field\n\nimport torch\nfro"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/quantizer/__init__.py",
    "chars": 154,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .config import DiffusionModuleQuantizerConfig\nfrom .quantizer import DiffusionActivationQu"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/quantizer/config.py",
    "chars": 17573,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantizatizer config.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nimport torch\nfrom omnicon"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/quantizer/quantizer.py",
    "chars": 14274,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Tensor Quantizer module.\"\"\"\n\nimport typing as tp\nfrom dataclasses import dataclass, field\n\nim"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/rotate.py",
    "chars": 6221,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Large Language Model Rotation module.\"\"\"\n\nimport gc\n\nimport torch\n\nfrom deepcompressor.calib."
  },
  {
    "path": "deepcompressor/app/diffusion/quant/smooth.py",
    "chars": 30633,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Diffusion smooth quantization module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn a"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/utils.py",
    "chars": 3895,
    "preview": "import typing as tp\n\nimport torch\nimport torch.nn as nn\n\nfrom ..nn.struct import DiffusionAttentionStruct, DiffusionFeed"
  },
  {
    "path": "deepcompressor/app/diffusion/quant/weight.py",
    "chars": 22087,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Diffusion model weight quantization calibration module.\"\"\"\n\nimport gc\nimport typing as tp\n\nim"
  },
  {
    "path": "deepcompressor/app/diffusion/utils.py",
    "chars": 6090,
    "preview": "import os\nimport random\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom deepcompressor.utils.common import "
  },
  {
    "path": "deepcompressor/app/llm/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/app/llm/cache/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/app/llm/cache/config.py",
    "chars": 1689,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM quantization cache configuration.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nfrom omni"
  },
  {
    "path": "deepcompressor/app/llm/config.py",
    "chars": 4728,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Configurations for evaluating a large language model.\"\"\"\n\nimport os\nimport random\nfrom datacl"
  },
  {
    "path": "deepcompressor/app/llm/eval/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/app/llm/eval/base.py",
    "chars": 694,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Language model evaluator base.\"\"\"\n\nfrom abc import ABC, abstractmethod\n\nfrom transformers imp"
  },
  {
    "path": "deepcompressor/app/llm/eval/config.py",
    "chars": 9090,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Language model evaluation config.\"\"\"\n\nimport random\nimport typing as tp\nfrom dataclasses impo"
  },
  {
    "path": "deepcompressor/app/llm/eval/custom.py",
    "chars": 3516,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Language model customized evaluator.\"\"\"\n\nimport math\n\nimport torch\nimport torch.nn as nn\nfrom"
  },
  {
    "path": "deepcompressor/app/llm/eval/lm_eval.py",
    "chars": 1829,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Language model evaluator using lm_eval.\"\"\"\n\nimport lm_eval\nimport lm_eval.models\nfrom transfo"
  },
  {
    "path": "deepcompressor/app/llm/eval/longbench/__init__.py",
    "chars": 54,
    "preview": "from .eval import LongbenchEvaluator, LongbenchScorer\n"
  },
  {
    "path": "deepcompressor/app/llm/eval/longbench/eval.py",
    "chars": 13148,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Language model evaluator for LongBench.\"\"\"\n\nimport json\nimport os\nimport typing as tp\n\nimport"
  },
  {
    "path": "deepcompressor/app/llm/eval/longbench/metrics.py",
    "chars": 5079,
    "preview": "\"\"\"LongBench metrics.\"\"\"\n\nimport re\nimport string\nfrom collections import Counter\n\nimport jieba\nfrom fuzzywuzzy import f"
  },
  {
    "path": "deepcompressor/app/llm/eval/longbench/task2prompt.json",
    "chars": 4977,
    "preview": "{\n    \"narrativeqa\": \"You are given a story, which can be either a novel or a movie script, and a question. Answer the q"
  },
  {
    "path": "deepcompressor/app/llm/model/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/app/llm/model/config.py",
    "chars": 6553,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Net configurations.\"\"\"\n\nimport typing as tp\nfrom dataclasses import dataclass, field\n\nimport "
  },
  {
    "path": "deepcompressor/app/llm/nn/__init__.py",
    "chars": 109,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .struct import LlmModelStruct, LlmTransformerBlockStruct, LlmTransformerStruct\n"
  },
  {
    "path": "deepcompressor/app/llm/nn/patch.py",
    "chars": 8993,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Llama model patcher.\"\"\"\n\nimport functools\n\nimport torch\nimport torch.nn as nn\nfrom transforme"
  },
  {
    "path": "deepcompressor/app/llm/nn/struct.py",
    "chars": 37512,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for Large Language Models.\"\"\"\n\n# region imports\nimport typing as tp\nfrom da"
  },
  {
    "path": "deepcompressor/app/llm/ptq.py",
    "chars": 17880,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Evaluate a large language model.\"\"\"\n\nimport gc\nimport json\nimport os\nimport pprint\nimport tra"
  },
  {
    "path": "deepcompressor/app/llm/quant/__init__.py",
    "chars": 332,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .activation import quantize_llm_activations\nfrom .config import LlmQuantCacheConfig, LlmQu"
  },
  {
    "path": "deepcompressor/app/llm/quant/activation.py",
    "chars": 11352,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM activation quantization calibration module.\"\"\"\n\nimport gc\nimport typing as tp\n\nimport tor"
  },
  {
    "path": "deepcompressor/app/llm/quant/config.py",
    "chars": 18564,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization config.\"\"\"\n\nimport os\nfrom dataclasses import dataclass, field\n\nimport torch\nfro"
  },
  {
    "path": "deepcompressor/app/llm/quant/dataset.py",
    "chars": 13498,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Functions for collecting calibration dataset for quantization.\"\"\"\n\nimport os\nimport random\nim"
  },
  {
    "path": "deepcompressor/app/llm/quant/quantizer/__init__.py",
    "chars": 136,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .config import LlmModuleQuantizerConfig\nfrom .quantizer import LlmActivationQuantizer, Llm"
  },
  {
    "path": "deepcompressor/app/llm/quant/quantizer/config.py",
    "chars": 10156,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantizatizer config.\"\"\"\n\nimport typing as tp\nfrom dataclasses import dataclass, field\n\nimpor"
  },
  {
    "path": "deepcompressor/app/llm/quant/quantizer/quantizer.py",
    "chars": 12056,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Tensor Quantizer module.\"\"\"\n\nimport typing as tp\nfrom dataclasses import dataclass, field\n\nim"
  },
  {
    "path": "deepcompressor/app/llm/quant/reorder.py",
    "chars": 16797,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM quantization channel reordering module.\"\"\"\n\nimport gc\nimport typing as tp\n\nimport torch\ni"
  },
  {
    "path": "deepcompressor/app/llm/quant/rotate.py",
    "chars": 8660,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Large Language Model Rotation module.\"\"\"\n\nimport gc\n\nimport torch\nfrom tqdm import tqdm\nfrom "
  },
  {
    "path": "deepcompressor/app/llm/quant/smooth.py",
    "chars": 10462,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM smooth quantization module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn as nn\nf"
  },
  {
    "path": "deepcompressor/app/llm/quant/utils.py",
    "chars": 3073,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM quantization utils module.\"\"\"\n\nimport typing as tp\n\nimport torch.nn as nn\n\nfrom ..nn.stru"
  },
  {
    "path": "deepcompressor/app/llm/quant/weight.py",
    "chars": 7685,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"LLM weight quantization calibration module.\"\"\"\n\nimport gc\nimport typing as tp\n\nimport torch\ni"
  },
  {
    "path": "deepcompressor/backend/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/backend/nunchaku/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/backend/nunchaku/convert.py",
    "chars": 18740,
    "preview": "\"\"\"Converts a DeepCompressor state dict to a Nunchaku state dict.\"\"\"\n\nimport argparse\nimport os\n\nimport safetensors.torc"
  },
  {
    "path": "deepcompressor/backend/nunchaku/convert_lora.py",
    "chars": 14185,
    "preview": "\"\"\"Convert LoRA weights to Nunchaku format.\"\"\"\n\nimport argparse\nimport os\n\nimport safetensors\nimport safetensors.torch\ni"
  },
  {
    "path": "deepcompressor/backend/nunchaku/utils.py",
    "chars": 19739,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Nunchaku backend utilities.\"\"\"\n\nimport torch\n\nfrom ..tinychat.utils import convert_to_tinycha"
  },
  {
    "path": "deepcompressor/backend/qserve/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/backend/qserve/convert.py",
    "chars": 9066,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"QServe state dict converter module.\"\"\"\n\nimport argparse\nimport os\n\nimport torch\nimport tqdm\n\n"
  },
  {
    "path": "deepcompressor/backend/qserve/utils.py",
    "chars": 10249,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"QServe backend utilities.\"\"\"\n\nimport torch\n\nfrom ..utils import MmaWeightPackerBase\n\n__all__ "
  },
  {
    "path": "deepcompressor/backend/tinychat/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/backend/tinychat/convert.py",
    "chars": 6941,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"QServe state dict converter module.\"\"\"\n\nimport argparse\nimport os\n\nimport safetensors.torch\ni"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/load.py",
    "chars": 1040,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"TinyChat Extension.\"\"\"\n\nimport os\n\nfrom torch.utils.cpp_extension import load\n\n__all__ = [\"_C"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/pybind.cpp",
    "chars": 368,
    "preview": "#include <pybind11/pybind11.h>\n#include <torch/extension.h>\n#include \"quantization/gemm/gemm_cuda.h\"\n#include \"quantizat"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/dequantize.cuh",
    "chars": 6201,
    "preview": "/*\nModified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/c"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/gemm/gemm_cuda.cu",
    "chars": 55402,
    "preview": "#include <cuda_fp16.h>\n#include <cuda_bf16.h>\n#include \"semaphore.h\"\n#include \"gemm_cuda.h\"\n#include \"../dequantize.cuh\""
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/gemm/gemm_cuda.h",
    "chars": 177,
    "preview": "#include <torch/extension.h>\n\ntorch::Tensor awq_gemm_forward_cuda(\n    torch::Tensor _in_feats,\n    torch::Tensor _kerne"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/gemm/semaphore.h",
    "chars": 3886,
    "preview": "/***************************************************************************************************\n * Copyright (c) 20"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/gemv/gemv_cuda.cu",
    "chars": 11067,
    "preview": "/*\n * Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db23574"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/quantization/gemv/gemv_cuda.h",
    "chars": 252,
    "preview": "#pragma once\n#include <torch/extension.h>\n\ntorch::Tensor awq_gemv_forward_cuda(\n    torch::Tensor _in_feats,\n    torch::"
  },
  {
    "path": "deepcompressor/backend/tinychat/csrc/utils.cuh",
    "chars": 11205,
    "preview": "// Adated from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransforme"
  },
  {
    "path": "deepcompressor/backend/tinychat/linear.py",
    "chars": 6593,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"TinyChat Quantized Linear Module\"\"\"\n\nimport warnings\n\nimport torch\nimport torch.nn as nn\n\nfro"
  },
  {
    "path": "deepcompressor/backend/tinychat/utils.py",
    "chars": 4367,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"TinyChat backend utilities.\"\"\"\n\nimport torch\n\nfrom ..utils import ceil_divide\n\n__all__ = [\"ce"
  },
  {
    "path": "deepcompressor/backend/utils.py",
    "chars": 6038,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Backend utilities.\"\"\"\n\nimport typing as tp\n\nimport safetensors\nimport torch\n\n__all__ = [\"ceil"
  },
  {
    "path": "deepcompressor/calib/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/calib/config/__init__.py",
    "chars": 549,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .lowrank import QuantLowRankCalibConfig, SkipBasedQuantLowRankCalibConfig\nfrom .range impo"
  },
  {
    "path": "deepcompressor/calib/config/lowrank.py",
    "chars": 4647,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization SVD calibration configuration.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nfro"
  },
  {
    "path": "deepcompressor/calib/config/range.py",
    "chars": 6681,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization dynamic range calibration configuration.\"\"\"\n\nfrom dataclasses import dataclass\n\n"
  },
  {
    "path": "deepcompressor/calib/config/reorder.py",
    "chars": 5582,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Channel reorder configuration.\"\"\"\n\nimport enum\nfrom dataclasses import dataclass, field\n\nfrom"
  },
  {
    "path": "deepcompressor/calib/config/rotation.py",
    "chars": 3214,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization Rotation configuration.\"\"\"\n\nimport os\nimport typing as tp\nfrom dataclasses impor"
  },
  {
    "path": "deepcompressor/calib/config/search.py",
    "chars": 4780,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization calibrator configurations.\"\"\"\n\nimport enum\nfrom dataclasses import dataclass\n\nfr"
  },
  {
    "path": "deepcompressor/calib/config/smooth.py",
    "chars": 17516,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Smooth quantization configuration.\"\"\"\n\nimport enum\nfrom dataclasses import dataclass, field\n\n"
  },
  {
    "path": "deepcompressor/calib/lowrank.py",
    "chars": 9693,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization SVD calibration module.\"\"\"\n\nfrom dataclasses import _MISSING_TYPE, MISSING\n\nimpo"
  },
  {
    "path": "deepcompressor/calib/metric.py",
    "chars": 7350,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Channel-wise metric calculation module.\"\"\"\n\nimport typing as tp\n\nimport torch\n\nfrom ..data.ut"
  },
  {
    "path": "deepcompressor/calib/range.py",
    "chars": 21119,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization dynamic range calibration.\"\"\"\n\nimport gc\nimport typing as tp\nfrom dataclasses im"
  },
  {
    "path": "deepcompressor/calib/reorder.py",
    "chars": 21619,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Channel reordering module.\"\"\"\n\nimport gc\nimport typing as tp\nfrom dataclasses import _MISSING"
  },
  {
    "path": "deepcompressor/calib/rotate.py",
    "chars": 10238,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Rotation Quantization module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn as nn\n\nfr"
  },
  {
    "path": "deepcompressor/calib/search.py",
    "chars": 51037,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Search-based uantization calibrator module.\"\"\"\n\nimport gc\nimport typing as tp\nfrom abc import"
  },
  {
    "path": "deepcompressor/calib/smooth.py",
    "chars": 47439,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Smooth quantization module.\"\"\"\n\nimport gc\nimport typing as tp\nfrom dataclasses import _MISSIN"
  },
  {
    "path": "deepcompressor/csrc/load.py",
    "chars": 951,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Deepcompressor Extension.\"\"\"\n\nimport os\n\nfrom torch.utils.cpp_extension import load\n\n__all__ "
  },
  {
    "path": "deepcompressor/csrc/pybind.cpp",
    "chars": 356,
    "preview": "#include <pybind11/pybind11.h>\n#include <torch/extension.h>\n\n#include \"quantize/quantize.h\"\n\nPYBIND11_MODULE(TORCH_EXTEN"
  },
  {
    "path": "deepcompressor/csrc/quantize/quantize.cu",
    "chars": 4126,
    "preview": "\n#include <assert.h>\n#include <cuda_bf16.h>\n#include <cuda_fp16.h>\n#include <float.h>\n#include <stdint.h>\n#include <stdi"
  },
  {
    "path": "deepcompressor/csrc/quantize/quantize.h",
    "chars": 322,
    "preview": "#pragma once\n#include <torch/extension.h>\n\ntorch::Tensor round_to_nearest_in_codebook_cuda(torch::Tensor tensor,\n       "
  },
  {
    "path": "deepcompressor/data/__init__.py",
    "chars": 199,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .dtype import QDType, QuantDataType\nfrom .range import DynamicRange, LogQuantRange, QuantR"
  },
  {
    "path": "deepcompressor/data/cache.py",
    "chars": 9875,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Activation cache module.\"\"\"\n\nimport math\nimport typing as tp\nfrom collections import OrderedD"
  },
  {
    "path": "deepcompressor/data/codebook.py",
    "chars": 7616,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Codebook for quantization.\"\"\"\n\nfrom dataclasses import dataclass\n\nimport torch\n\nfrom deepcomp"
  },
  {
    "path": "deepcompressor/data/common.py",
    "chars": 230,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Common uantization data.\"\"\"\n\nimport enum\n\n__all__ = [\"TensorType\"]\n\n\nclass TensorType(enum.En"
  },
  {
    "path": "deepcompressor/data/dtype.py",
    "chars": 16225,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization data type.\"\"\"\n\nimport typing as tp\n\nimport torch\n\nfrom .codebook import Codebook"
  },
  {
    "path": "deepcompressor/data/range.py",
    "chars": 18548,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Dynamic range calculation for quantization.\"\"\"\n\nimport math\nimport typing as tp\nfrom dataclas"
  },
  {
    "path": "deepcompressor/data/scale.py",
    "chars": 4272,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization scale module.\"\"\"\n\nimport typing as tp\n\nimport torch\n\n__all__ = [\"QuantScale\"]\n\n\n"
  },
  {
    "path": "deepcompressor/data/tensor.py",
    "chars": 1252,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantized tensor module.\"\"\"\n\nimport torch\n\nfrom .scale import QuantScale\n\n__all__ = [\"QuantTe"
  },
  {
    "path": "deepcompressor/data/utils/__init__.py",
    "chars": 127,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom . import dtype as DtypeUtils\nfrom . import scale as ScaleUtils\nfrom . import shape as Shap"
  },
  {
    "path": "deepcompressor/data/utils/dtype.py",
    "chars": 3494,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for dtype in quantization.\"\"\"\n\nimport torch\n\nfrom ..dtype import QuantDataT"
  },
  {
    "path": "deepcompressor/data/utils/reshape.py",
    "chars": 4821,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Type hints used in deepcompressor.\"\"\"\n\nimport torch\nimport torch.nn.functional as F\n\n__all__ "
  },
  {
    "path": "deepcompressor/data/utils/scale.py",
    "chars": 1853,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for quantization scale.\"\"\"\n\nimport typing as tp\n\nimport torch\n\nfrom ..dtype"
  },
  {
    "path": "deepcompressor/data/utils/shape.py",
    "chars": 7996,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for shape calulation in quantization.\"\"\"\n\nimport typing as tp\n\nimport torch"
  },
  {
    "path": "deepcompressor/data/zero.py",
    "chars": 224,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Zero-point for quantization.\"\"\"\n\nimport enum\n\n__all__ = [\"ZeroPointDomain\"]\n\n\nclass ZeroPoint"
  },
  {
    "path": "deepcompressor/dataset/__init__.py",
    "chars": 157,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .action import CacheAction, ConcatCacheAction\nfrom .cache import BaseCalibCacheLoader\nfrom"
  },
  {
    "path": "deepcompressor/dataset/action.py",
    "chars": 7987,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Actions for caching inputs and outputs.\"\"\"\n\nimport typing as tp\nfrom abc import ABC, abstract"
  },
  {
    "path": "deepcompressor/dataset/cache.py",
    "chars": 21085,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Caching calibration dataset.\"\"\"\n\nimport functools\nimport gc\nimport typing as tp\nfrom abc impo"
  },
  {
    "path": "deepcompressor/dataset/config.py",
    "chars": 1404,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Configuration for collecting calibration dataset for quantization.\"\"\"\n\nfrom abc import ABC, a"
  },
  {
    "path": "deepcompressor/nn/__init__.py",
    "chars": 24,
    "preview": "# -*- coding: utf-8 -*-\n"
  },
  {
    "path": "deepcompressor/nn/patch/__init__.py",
    "chars": 110,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .conv import *\nfrom .linear import *\nfrom .lowrank import *\nfrom .sdpa import *\n"
  },
  {
    "path": "deepcompressor/nn/patch/conv.py",
    "chars": 6808,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Concat Convolution 2d Module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn as nn\nimp"
  },
  {
    "path": "deepcompressor/nn/patch/linear.py",
    "chars": 4794,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Concat Linear Module.\"\"\"\n\nimport torch\nimport torch.nn as nn\n\n__all__ = [\"ConcatLinear\", \"Shi"
  },
  {
    "path": "deepcompressor/nn/patch/lowrank.py",
    "chars": 4095,
    "preview": "# -*- coding: utf-8 -*-\n\nimport torch\nimport torch.linalg\nimport torch.nn as nn\n\nfrom ...utils.hooks import AccumBranchH"
  },
  {
    "path": "deepcompressor/nn/patch/sdpa.py",
    "chars": 691,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Sparse attention module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn as nn\nimport t"
  },
  {
    "path": "deepcompressor/nn/struct/__init__.py",
    "chars": 65,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .attn import *\nfrom .base import *\n"
  },
  {
    "path": "deepcompressor/nn/struct/attn.py",
    "chars": 32134,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Transformer and attention module struct.\"\"\"\n\nimport typing as tp\nfrom abc import abstractmeth"
  },
  {
    "path": "deepcompressor/nn/struct/base.py",
    "chars": 5935,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for Module Struct.\"\"\"\n\nimport types\nimport typing as tp\nfrom abc import ABC"
  },
  {
    "path": "deepcompressor/quantizer/__init__.py",
    "chars": 58,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .processor import Quantizer\n"
  },
  {
    "path": "deepcompressor/quantizer/config/__init__.py",
    "chars": 249,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .base import BaseQuantizerConfig, DecomposedQuantizerConfig, ProgressiveQuantizerConfig, Q"
  },
  {
    "path": "deepcompressor/quantizer/config/base.py",
    "chars": 15332,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization kernel config.\"\"\"\n\nimport typing as tp\nfrom abc import abstractmethod\nfrom datac"
  },
  {
    "path": "deepcompressor/quantizer/config/kernel.py",
    "chars": 5232,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantizatizer kernel configurations.\"\"\"\n\nfrom abc import ABC, abstractmethod\nfrom dataclasses"
  },
  {
    "path": "deepcompressor/quantizer/config/lowrank.py",
    "chars": 1358,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom dataclasses import dataclass\n\nfrom omniconfig import configclass\n\nfrom ...utils.common imp"
  },
  {
    "path": "deepcompressor/quantizer/impl/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "deepcompressor/quantizer/impl/base.py",
    "chars": 14507,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantizer.\"\"\"\n\nimport typing as tp\nfrom dataclasses import dataclass, field\n\nimport torch\n\nfr"
  },
  {
    "path": "deepcompressor/quantizer/impl/info.py",
    "chars": 6728,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization information class.\"\"\"\n\nfrom dataclasses import dataclass, field\n\nimport torch\n\nf"
  },
  {
    "path": "deepcompressor/quantizer/impl/scale.py",
    "chars": 11502,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantization scale module.\"\"\"\n\nimport math\nimport typing as tp\nfrom dataclasses import datacl"
  },
  {
    "path": "deepcompressor/quantizer/impl/simple.py",
    "chars": 2459,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Simple quantization functions.\"\"\"\n\nimport torch\n\nfrom ...data.dtype import QuantDataType\nfrom"
  },
  {
    "path": "deepcompressor/quantizer/impl/ste.py",
    "chars": 779,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Simple quantization functions.\"\"\"\n\nimport typing as tp\n\nimport torch\n\n__all__ = [\"ste\"]\n\n\ncla"
  },
  {
    "path": "deepcompressor/quantizer/kernel/__init__.py",
    "chars": 137,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .gptq import QuantGptqConfig, QuantGptqKernel, gptq_quantize\nfrom .rtn import QuantRtnKern"
  },
  {
    "path": "deepcompressor/quantizer/kernel/gptq.py",
    "chars": 12177,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"GPTQ Quantization kernel.\"\"\"\n\nimport gc\nimport math\nfrom dataclasses import dataclass\n\nimport"
  },
  {
    "path": "deepcompressor/quantizer/kernel/rtn.py",
    "chars": 3786,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Round-to-nearest (RTN) quantization module.\"\"\"\n\nimport torch\n\nfrom ...data.dtype import Quant"
  },
  {
    "path": "deepcompressor/quantizer/processor.py",
    "chars": 17147,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Quantizer.\"\"\"\n\nimport typing as tp\nfrom dataclasses import _MISSING_TYPE, MISSING, dataclass\n"
  },
  {
    "path": "deepcompressor/utils/__init__.py",
    "chars": 68,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .common import *\nfrom .patch import *\n"
  },
  {
    "path": "deepcompressor/utils/common.py",
    "chars": 6556,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Common utilities.\"\"\"\n\nimport typing as tp\n\nimport numpy as np\nimport torch\n\n__all__ = [\n    \""
  },
  {
    "path": "deepcompressor/utils/config/__init__.py",
    "chars": 85,
    "preview": "from .base import EnableConfig, IncludeBasedConfig, KeyEnableConfig, SkipBasedConfig\n"
  },
  {
    "path": "deepcompressor/utils/config/base.py",
    "chars": 7204,
    "preview": "# -*- coding: utf-8 -*-\n\nimport typing as tp\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, fiel"
  },
  {
    "path": "deepcompressor/utils/config/model.py",
    "chars": 1619,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Net configurations.\"\"\"\n\nimport os\nimport typing as tp\nfrom abc import ABC, abstractmethod\nfro"
  },
  {
    "path": "deepcompressor/utils/config/output.py",
    "chars": 3874,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Output configuration.\"\"\"\n\nimport os\nfrom dataclasses import dataclass, field\nfrom datetime im"
  },
  {
    "path": "deepcompressor/utils/config/path.py",
    "chars": 2690,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Path configuration.\"\"\"\n\nimport os\nimport typing as tp\n\nfrom ..dataclass import get_fields\n\n__"
  },
  {
    "path": "deepcompressor/utils/dataclass.py",
    "chars": 1066,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Dataclass utilities.\"\"\"\n\nfrom dataclasses import _FIELD, _FIELD_CLASSVAR, _FIELD_INITVAR, _FI"
  },
  {
    "path": "deepcompressor/utils/hooks/__init__.py",
    "chars": 331,
    "preview": "from .branch import AccumBranchHook\nfrom .hook import EarlyStopException, EarlyStopHook, Hook, IOHook\nfrom .packager imp"
  },
  {
    "path": "deepcompressor/utils/hooks/branch.py",
    "chars": 2256,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Branch hook module.\"\"\"\n\nimport typing as tp\n\nimport torch\nimport torch.nn as nn\n\nfrom .hook i"
  },
  {
    "path": "deepcompressor/utils/hooks/hook.py",
    "chars": 7187,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"nn.Module Hook.\"\"\"\n\nimport typing as tp\nfrom collections import defaultdict\n\nimport torch\nimp"
  },
  {
    "path": "deepcompressor/utils/hooks/packager.py",
    "chars": 9943,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Packagers for input and output tensors in hooks.\"\"\"\n\nimport functools\nimport inspect\nimport t"
  },
  {
    "path": "deepcompressor/utils/hooks/processor.py",
    "chars": 3139,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Tensor processor.\"\"\"\n\nimport abc\nimport typing as tp\n\nimport torch\nimport torch.ao.quantizati"
  },
  {
    "path": "deepcompressor/utils/math/__init__.py",
    "chars": 75,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom .functional import *\nfrom .hadamard import *\n"
  },
  {
    "path": "deepcompressor/utils/math/functional.py",
    "chars": 740,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Math utility functions.\"\"\"\n\nimport torch\n\n__all__ = [\"is_pow2\", \"root_\"]\n\n\ndef is_pow2(n: int"
  },
  {
    "path": "deepcompressor/utils/math/hadamard.py",
    "chars": 2321799,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Utility functions for quantization hadamard transformation.\"\"\"\n\nimport typing as tp\n\nimport s"
  },
  {
    "path": "deepcompressor/utils/patch.py",
    "chars": 2229,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Monkey-patching utilities.\"\"\"\n\nimport copy\nimport functools\nimport types\nimport typing\n\nimpor"
  },
  {
    "path": "deepcompressor/utils/tools/__init__.py",
    "chars": 52,
    "preview": "# -*- coding: utf-8 -*-\n\nfrom . import logging, sys\n"
  },
  {
    "path": "deepcompressor/utils/tools/logging.py",
    "chars": 6658,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Logging tools.\"\"\"\n\nimport logging\nimport sys\nimport typing as tp\n\nfrom tqdm.contrib.logging i"
  },
  {
    "path": "deepcompressor/utils/tools/sys.py",
    "chars": 1212,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"System tools.\"\"\"\n\nimport psutil\nimport torch\n\n__all__ = [\"get_max_memory_map\"]\n\n\ndef _get_vis"
  },
  {
    "path": "deepcompressor/version.py",
    "chars": 74,
    "preview": "# -*- coding: utf-8 -*-\n\"\"\"Version information.\"\"\"\n\n__version__ = \"0.0.2\"\n"
  },
  {
    "path": "environment.yml",
    "chars": 85,
    "preview": "channels:\n  - defaults\ndependencies:\n  - python=3.12\n  - pip\n  - pip:\n      - poetry\n"
  },
  {
    "path": "examples/diffusion/.gitignore",
    "chars": 126,
    "preview": ".tmp\n.tmp/\nbaselines\nbaselines/\nbenchmarks\nbenchmarks/\ncaches\ncaches/\ndatasets\ndatasets/\nvisualize/runs\nvisualize/runs/\n"
  },
  {
    "path": "examples/diffusion/README.md",
    "chars": 12905,
    "preview": "# SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models\n\n[[Website](https://hanlab.mit.edu/proj"
  },
  {
    "path": "examples/diffusion/configs/__default__.yaml",
    "chars": 2439,
    "preview": "seed: 12345\nenable_cache: true\ncache:\n  root: runs\noutput:\n  root: runs\n  dirname: default\npipeline:\n  dtype: torch.floa"
  },
  {
    "path": "examples/diffusion/configs/collect/qdiff.yaml",
    "chars": 98,
    "preview": "collect:\n  root: datasets\n  dataset_name: qdiff\n  data_path: prompts/qdiff.yaml\n  num_samples: 128"
  },
  {
    "path": "examples/diffusion/configs/lora/__default__.yaml",
    "chars": 46,
    "preview": "pipeline:\n  enable_lora: true\nskip_eval: true\n"
  },
  {
    "path": "examples/diffusion/configs/lora/flux.1-dev/anime.yaml",
    "chars": 326,
    "preview": "# https://huggingface.co/alvdansen/sonny-anime-fixed\n# alvdansen/sonny-anime-fixed\n# separate, rank=16\neval:\n  benchmark"
  },
  {
    "path": "examples/diffusion/configs/lora/flux.1-dev/ghibsky.yaml",
    "chars": 334,
    "preview": "# https://huggingface.co/aleksa-codes/flux-ghibsky-illustration\n# aleksa-codes/flux-ghibsky-illustration\n# separate, ran"
  },
  {
    "path": "examples/diffusion/configs/lora/flux.1-dev/realism.yaml",
    "chars": 333,
    "preview": "# https://huggingface.co/XLabs-AI/flux-RealismLora\n# XLabs-AI/flux-RealismLora\n# qkv fused, rank=16, only joint blocks\ne"
  },
  {
    "path": "examples/diffusion/configs/lora/flux.1-dev/sketch.yaml",
    "chars": 410,
    "preview": "# https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch/tree/main\n# Shakker-Labs/FLUX.1-dev-LoRA-Ch"
  },
  {
    "path": "examples/diffusion/configs/lora/flux.1-dev/yarn.yaml",
    "chars": 337,
    "preview": "# https://huggingface.co/linoyts/yarn_art_Flux_LoRA\n# linoyts/yarn_art_Flux_LoRA\n# separate, rank=4, both joint and sing"
  },
  {
    "path": "examples/diffusion/configs/model/flux.1-dev.yaml",
    "chars": 1181,
    "preview": "pipeline:\n  name: flux.1-dev\n  dtype: torch.bfloat16\neval:\n  num_steps: 50\n  guidance_scale: 3.5\n  protocol: fmeuler{num"
  },
  {
    "path": "examples/diffusion/configs/model/flux.1-schnell.yaml",
    "chars": 1182,
    "preview": "pipeline:\n  name: flux.1-schnell\n  dtype: torch.bfloat16\neval:\n  num_steps: 4\n  guidance_scale: 0\n  protocol: fmeuler{nu"
  },
  {
    "path": "examples/diffusion/configs/model/pixart-sigma.yaml",
    "chars": 921,
    "preview": "pipeline:\n  name: pixart-sigma\neval:\n  num_steps: 20\n  guidance_scale: 4.5\n  protocol: dpm{num_steps}-g{guidance_scale}\n"
  },
  {
    "path": "examples/diffusion/configs/model/sana-1.6b.yaml",
    "chars": 1368,
    "preview": "pipeline:\n  name: sana-1.6b-1024px-bf16-ch5632\n  path: Lawrence-cj/Sana_1600M_1024px_BF16_diffusers_ch5632\n  dtype: torc"
  }
]

// ... and 29 more files (download for full content)
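Each record in the index above shares one shape: a repo-relative `path`, a `chars` byte/character count, and a truncated `preview` of the file's opening text. The sketch below shows one way to work with that structure once the full `.txt`/JSON is downloaded; it is a minimal example, assuming the download preserves this record shape — the helper names (`largest_files`, `files_under`) and the two embedded sample records are illustrative, not part of GitExtract's API.

```python
import json

# Two sample records copied from the index above, in the same shape
# (path / chars / preview) that every entry in the full export uses.
sample_index = json.loads("""
[
  {"path": "deepcompressor/version.py", "chars": 74,
   "preview": "# -*- coding: utf-8 -*-"},
  {"path": "deepcompressor/utils/math/hadamard.py", "chars": 2321799,
   "preview": "# -*- coding: utf-8 -*-"}
]
""")


def largest_files(index, top=1):
    """Return the `top` largest records by character count."""
    return sorted(index, key=lambda rec: rec["chars"], reverse=True)[:top]


def files_under(index, prefix):
    """Return paths under a directory prefix, e.g. 'deepcompressor/utils/'."""
    return [rec["path"] for rec in index if rec["path"].startswith(prefix)]


# The 2.3 MB hadamard.py dominates this sample, as it does the real repo.
print(largest_files(sample_index)[0]["path"])
```

Sorting by `chars` first is useful in practice: a single generated file (here, the Hadamard-matrix tables) can account for most of a repo's extracted size, and trimming it keeps the text within an LLM's context window.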

About this extraction

This page contains the full source code of the mit-han-lab/deepcompressor GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 229 files (3.7 MB, approximately 970k tokens) and includes a symbol index of 1198 extracted functions, classes, methods, constants, and types. Use it with any AI tool that accepts text input, such as Claude, ChatGPT, Cursor, or Windsurf. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
