Full Code of hustvl/ControlAR for AI

main 87e32e974162 cached
184 files
2.6 MB
696.4k tokens
813 symbols
1 requests
Download .txt
Showing preview only (2,791K chars total). Download the full file or copy to clipboard to get everything.
Repository: hustvl/ControlAR
Branch: main
Commit: 87e32e974162
Files: 184
Total size: 2.6 MB

Directory structure:
gitextract_x82jk08x/

├── .gitignore
├── LICENSE
├── README.md
├── autoregressive/
│   ├── models/
│   │   ├── README.md
│   │   ├── dinov2_adapter.py
│   │   ├── generate.py
│   │   ├── gpt.py
│   │   ├── gpt_t2i.py
│   │   └── vit_adapter.py
│   ├── sample/
│   │   ├── sample_c2i.py
│   │   ├── sample_c2i_ddp.py
│   │   ├── sample_t2i.py
│   │   ├── sample_t2i_MR.py
│   │   └── sample_t2i_ddp.py
│   ├── serve/
│   │   ├── README.md
│   │   ├── fake_json/
│   │   │   ├── GPT-3B.json
│   │   │   ├── GPT-B.json
│   │   │   ├── GPT-L.json
│   │   │   ├── GPT-XL.json
│   │   │   └── GPT-XXL.json
│   │   ├── gpt_model.py
│   │   ├── gpu_executor.py
│   │   ├── llm.py
│   │   ├── llm_engine.py
│   │   ├── model_runner.py
│   │   ├── sample_c2i.py
│   │   ├── sampler.py
│   │   └── worker.py
│   ├── test/
│   │   ├── metric.py
│   │   ├── test_c2i.py
│   │   ├── test_ssim.py
│   │   └── test_t2i.py
│   └── train/
│       ├── extract_codes_c2i.py
│       ├── extract_codes_t2i.py
│       ├── extract_file_ade.py
│       ├── extract_file_cocostuff.py
│       ├── extract_file_imagenet.py
│       ├── extract_file_multigen.py
│       ├── train_c2i.py
│       ├── train_c2i_canny.py
│       ├── train_c2i_depth.py
│       ├── train_c2i_fsdp.py
│       ├── train_t2i.py
│       ├── train_t2i_canny.py
│       ├── train_t2i_depth.py
│       ├── train_t2i_depth_multiscale.py
│       ├── train_t2i_hed.py
│       ├── train_t2i_hed_multiscale.py
│       ├── train_t2i_lineart.py
│       ├── train_t2i_lineart_multiscale.py
│       ├── train_t2i_seg.py
│       └── train_t2i_seg_multiscale.py
├── condition/
│   ├── README.md
│   ├── canny.py
│   ├── depth.py
│   ├── example/
│   │   └── c2i/
│   │       ├── canny/
│   │       │   ├── 15000.npy
│   │       │   ├── 2312.npy
│   │       │   ├── 48850.npy
│   │       │   └── 650.npy
│   │       └── depth/
│   │           ├── 101.npy
│   │           ├── 10601.npy
│   │           ├── 4351.npy
│   │           └── 48901.npy
│   ├── hed.py
│   ├── lineart.py
│   ├── midas/
│   │   ├── depth.py
│   │   └── midas/
│   │       ├── __init__.py
│   │       ├── base_model.py
│   │       ├── blocks.py
│   │       ├── dpt_depth.py
│   │       ├── midas_net.py
│   │       ├── midas_net_custom.py
│   │       ├── transforms.py
│   │       └── vit.py
│   └── utils.py
├── create_npz.py
├── dataset/
│   ├── augmentation.py
│   ├── build.py
│   ├── coco.py
│   ├── imagenet.py
│   ├── openimage.py
│   ├── pexels.py
│   ├── t2i.py
│   ├── t2i_control.py
│   └── utils.py
├── demo/
│   ├── app.py
│   ├── app_depth.py
│   ├── app_edge.py
│   └── model.py
├── evaluations/
│   ├── ade20k_mIoU.py
│   ├── c2i/
│   │   ├── README.md
│   │   └── evaluator.py
│   ├── canny_f1score.py
│   ├── clean_fid.py
│   ├── cocostuff_mIoU.py
│   ├── depth_rmse.py
│   ├── hed_ssim.py
│   ├── lineart_ssim.py
│   └── t2i/
│       ├── PartiPrompts.tsv
│       ├── README.md
│       ├── coco_captions.csv
│       └── evaluation.py
├── language/
│   ├── README.md
│   ├── extract_t5_feature.py
│   └── t5.py
├── requirements.txt
├── scripts/
│   ├── autoregressive/
│   │   ├── extract_codes_c2i.sh
│   │   ├── extract_file_ade.sh
│   │   ├── extract_file_cocostuff.sh
│   │   ├── extract_file_imagenet.sh
│   │   ├── extract_file_multigen.sh
│   │   ├── sample_c2i.sh
│   │   ├── sample_t2i_coco.sh
│   │   ├── sample_t2i_parti.sh
│   │   ├── test_c2i.sh
│   │   ├── test_t2i.sh
│   │   ├── train_c2i.sh
│   │   ├── train_c2i_canny.sh
│   │   ├── train_c2i_depth.sh
│   │   ├── train_c2i_fsdp.sh
│   │   ├── train_t2i_canny.sh
│   │   ├── train_t2i_depth.sh
│   │   ├── train_t2i_depth_multiscale.sh
│   │   ├── train_t2i_hed.sh
│   │   ├── train_t2i_hed_multiscale.sh
│   │   ├── train_t2i_lineart.sh
│   │   ├── train_t2i_lineart_multiscale.sh
│   │   ├── train_t2i_seg.sh
│   │   ├── train_t2i_seg_multiscale.sh
│   │   ├── train_t2i_stage1.sh
│   │   └── train_t2i_stage2.sh
│   ├── language/
│   │   ├── extract_flan_t5_feat_laion_coco_stage1.sh
│   │   ├── extract_flan_t5_feat_stage2.sh
│   │   └── extract_flan_t5_feat_trunc_stage2.sh
│   └── tokenizer/
│       ├── reconstruction_consistency_decoder.sh
│       ├── reconstruction_vae.sh
│       ├── reconstruction_vq.sh
│       ├── reconstruction_vqgan.sh
│       ├── train_vq.sh
│       ├── train_vq_finetune.sh
│       ├── train_vq_finetune_continue.sh
│       └── val.sh
├── tokenizer/
│   ├── consistencydecoder/
│   │   ├── README.md
│   │   ├── cd_demo.py
│   │   └── reconstruction_cd_ddp.py
│   ├── tokenizer_image/
│   │   ├── cache/
│   │   │   └── vgg.pth
│   │   ├── discriminator.py
│   │   ├── discriminator_patchgan.py
│   │   ├── discriminator_stylegan.py
│   │   ├── lpips.py
│   │   ├── reconstruction_vq_ddp.py
│   │   ├── vq_demo.py
│   │   ├── vq_loss.py
│   │   ├── vq_model.py
│   │   ├── vq_model_hf.py
│   │   └── vq_train.py
│   ├── vae/
│   │   ├── README.md
│   │   ├── reconstruction_vae_ddp.py
│   │   └── sd_vae_demo.py
│   ├── validation/
│   │   └── val_ddp.py
│   └── vqgan/
│       ├── README.md
│       ├── configs/
│       │   ├── vqgan_imagenet_f16_1024.yaml
│       │   ├── vqgan_imagenet_f16_16384.yaml
│       │   ├── vqgan_openimage_f8_16384.yaml
│       │   └── vqgan_openimage_f8_256.yaml
│       ├── layer.py
│       ├── model.py
│       ├── quantize.py
│       ├── reconstruction_vqgan_ddp.py
│       └── taming_vqgan_demo.py
├── tools/
│   ├── check_image_codes.py
│   ├── convert_pytorch_lightning_to_torch.py
│   ├── draw_figure.py
│   ├── imagenet_en_cn.py
│   ├── openimage_json.py
│   ├── push_gpt_to_hf.py
│   └── push_vae_to_hf.py
└── utils/
    ├── data.py
    ├── deepspeed.py
    ├── distributed.py
    ├── drop_path.py
    ├── ema.py
    ├── logger.py
    └── video.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<div align ="center">
<img src="./assets/logo.jpeg" width="20%">
<h1> ControlAR </h1>
<h3> Controllable Image Generation with Autoregressive Models </h3>

Zongming Li<sup>1,\*</sup>, [Tianheng Cheng](https://scholar.google.com/citations?user=PH8rJHYAAAAJ&hl=zh-CN)<sup>1,\*</sup>, [Shoufa Chen](https://shoufachen.com/)<sup>2</sup>, [Peize Sun](https://peizesun.github.io/)<sup>2</sup>, Haocheng Shen<sup>3</sup>,Longjin Ran<sup>3</sup>, Xiaoxin Chen<sup>3</sup>, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu)<sup>1</sup>, [Xinggang Wang](https://xwcv.github.io/)<sup>1,📧</sup>

<sup>1</sup> Huazhong University of Science and Technology,
<sup>2</sup> The University of Hong Kong
<sup>3</sup> vivo AI Lab

<b>ICLR 2025</b>

(\* equal contribution, 📧 corresponding author)

[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2410.02705)
[![demo](https://img.shields.io/badge/Demo-🤗-orange)](https://huggingface.co/spaces/wondervictor/ControlAR)
[![checkpoints](https://img.shields.io/badge/HuggingFace-🤗-green)](https://huggingface.co/wondervictor/ControlAR)

</div>


<div align="center">
<img src="./assets/vis.png">
</div>


## News
`[2025-01-23]:` Our ControlAR has been accepted by ICLR 2025 🚀 !\
`[2024-12-12]:` We introduce a control strength factor, employ a larger control encoder(dinov2-base), and optimize text alignment capabilities along with generation diversity. New model weight: depth_base.safetensors and edge_base.safetensors. The edge_base.safetensors can handle three types of edges, including Canny, HED, and Lineart.\
`[2024-10-31]:` The code and models have been released!\
`[2024-10-04]:` We have released the [technical report of ControlAR](https://arxiv.org/abs/2410.02705). Code, models, and demos are coming soon!


## Highlights

* ControlAR explores an effective yet simple *conditional decoding* strategy for adding spatial controls to autoregressive models, e.g., [LlamaGen](https://github.com/FoundationVision/LlamaGen), from a sequence perspective.

* ControlAR supports *arbitrary-resolution* image generation with autoregressive models without hand-crafted special tokens or resolution-aware prompts.

## TODO

- [x] release code & models.
- [x] release demo code and HuggingFace demo: [HuggingFace Spaces 🤗](https://huggingface.co/spaces/wondervictor/ControlAR)


## Results

We provide both quantitative and qualitative comparisons with diffusion-based methods in the technical report! 

<div align="center">
<img src="./assets/comparison.png">
</div>


## Models

We released checkpoints of text-to-image ControlAR on different controls and settings, *i.e.* arbitrary-resolution generation.

| AR Model | Type | Control encoder | Control | Arbitrary-Resolution | Checkpoint |
| :--------| :--: | :-------------: | :-----: | :------------------: | :--------: |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-small | Canny Edge | ✅ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/canny_MR.safetensors) |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-small | Depth | ✅ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/depth_MR.safetensors) |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-small | HED Edge | ❌ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/hed.safetensors) |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-small | Seg. Mask | ❌ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/seg_cocostuff.safetensors) |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-base | Edge (Canny, Hed, Lineart) | ❌ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/edge_base.safetensors) |
| [LlamaGen-XL](https://github.com/FoundationVision/LlamaGen#-text-conditional-image-generation) | t2i | DINOv2-base | Depth | ❌ | [ckpt](https://huggingface.co/wondervictor/ControlAR/blob/main/depth_base.safetensors) |



## Getting Started

### Installation

```bash
conda create -n ControlAR python=3.10
git clone https://github.com/hustvl/ControlAR.git
cd ControlAR
pip install torch==2.1.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip3 install -U openmim 
mim install mmengine 
mim install "mmcv==2.1.0"
pip3 install "mmsegmentation>=1.0.0"
pip3 install mmdet
git clone https://github.com/open-mmlab/mmsegmentation.git
```

### Pretrained Checkpoints for ControlAR

|tokenizer| text encoder |LlamaGen-B|LlamaGen-L|LlamaGen-XL|
|:-------:|:------------:|:--------:|:--------:|:---------:|
|[vq_ds16_t2i.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt)|[flan-t5-xl](https://huggingface.co/google/flan-t5-xl)|[c2i_B_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_B_256.pt)|[c2i_L_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_L_256.pt)|[t2i_XL_512.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/t2i_XL_stage2_512.pt)|

We recommend storing them in the following structures:
```
|---checkpoints
      |---t2i
            |---canny/canny_MR.safetensors
            |---hed/hed.safetensors
            |---depth/depth_MR.safetensors
            |---seg/seg_cocostuff.safetensors
            |---edge_base.safetensors
            |---depth_base.safetensors
      |---t5-ckpt
            |---flan-t5-xl
                  |---config.json
                  |---pytorch_model-00001-of-00002.bin
                  |---pytorch_model-00002-of-00002.bin
                  |---pytorch_model.bin.index.json
                  |---tokenizer.json
      |---vq
            |---vq_ds16_c2i.pt
            |---vq_ds16_t2i.pt
      |---llamagen (Only necessary for training)
            |---c2i_B_256.pt
            |---c2i_L_256.pt
            |---t2i_XL_stage2_512.pt
```

### Demo

Coming soon...


###  Sample & Generation

#### 1. Class-to-image generation

```bash
python autoregressive/sample/sample_c2i.py \
--vq-ckpt checkpoints/vq/vq_ds16_c2i.pt \
--gpt-ckpt checkpoints/c2i/canny/LlamaGen-L.pt \
--gpt-model GPT-L --seed 0 --condition-type canny
```

#### 2. Text-to-image generation

*Generate an image using HED edge and text-to-image ControlAR:*

```bash
python autoregressive/sample/sample_t2i.py \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/hed/hed.safetensors \
--gpt-model GPT-XL --image-size 512 \
--condition-type hed --seed 0 --condition-path condition/example/t2i/multigen/eye.png
```
*Generate an image using segmentation mask and text-to-image ControlAR:*

```bash
python autoregressive/sample/sample_t2i.py \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/seg/seg_cocostuff.safetensors \
--gpt-model GPT-XL --image-size 512 \
--condition-type seg --seed 0 --condition-path condition/example/t2i/cocostuff/doll.png \
--prompt 'A stuffed animal wearing a mask and a leash, sitting on a pink blanket'
```

#### 3. Text-to-image generation with adjustable control strength
*Generate an image using depth map and text-to-image ControlAR:*

```bash
python autoregressive/sample/sample_t2i.py \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/depth_base.safetensors \
--gpt-model GPT-XL --image-size 512 \
--condition-type depth --seed 0 --condition-path condition/example/t2i/multigen/bird.jpg \
--prompt 'A bird made of blue crystal' \
--adapter-size base \
--control-strength 0.6
```

*Generate an image using lineart edge and text-to-image ControlAR:*

```bash
python autoregressive/sample/sample_t2i.py \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/edge_base.safetensors \
--gpt-model GPT-XL --image-size 512 \
--condition-type lineart --seed 0 --condition-path condition/example/t2i/multigen/girl.jpg \
--prompt 'A girl with blue hair' \
--adapter-size base \
--control-strength 0.6
```

(you can change lineart to canny_base or hed)


#### 4. Arbitrary-resolution generation

```bash
python3 autoregressive/sample/sample_t2i_MR.py --vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/depth_MR.safetensors --gpt-model GPT-XL --image-size 768 \
--condition-type depth --condition-path condition/example/t2i/multi_resolution/bird.jpg \
--prompt 'colorful bird' --seed 0
```

```bash
python3 autoregressive/sample/sample_t2i_MR.py --vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/canny_MR.safetensors --gpt-model GPT-XL --image-size 768 \
--condition-type canny --condition-path condition/example/t2i/multi_resolution/bird.jpg \
--prompt 'colorful bird' --seed 0
```

### Preparing Datasets
We provide the dataset details for evaluation and training. If you don't want to train ControlAR, just download the validation splits.

#### 1. Class-to-image
* Download [ImageNet](https://image-net.org/) and save it to `data/imagenet/data`.

#### 2. Text-to-image
* Download [ADE20K with caption](https://huggingface.co/datasets/limingcv/Captioned_ADE20K)(~7GB) and save the `.parquet` files to `data/Captioned_ADE20K/data`. 
* Download [COCOStuff with caption](https://huggingface.co/datasets/limingcv/Captioned_COCOStuff)( ~62GB) and save the .parquet files to `data/Captioned_COCOStuff/data`.  
* Download [MultiGen-20M](https://huggingface.co/datasets/limingcv/MultiGen-20M_depth)( ~1.22TB) and save the .parquet files to `data/MultiGen20M/data`.  

#### 3. Preprocessing datasets
To save training time, we adopt the tokenizer to pre-process the images with the text prompts.

* ImageNet
```bash
bash scripts/autoregressive/extract_file_imagenet.sh \
--vq-ckpt checkpoints/vq/vq_ds16_c2i.pt \
--data-path data/imagenet/data/val \
--code-path data/imagenet/val/imagenet_code_c2i_flip_ten_crop \
--ten-crop --crop-range 1.1 --image-size 256
```
* ADE20k
```sh
bash scripts/autoregressive/extract_file_ade.sh \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--data-path data/Captioned_ADE20K/data --code-path data/Captioned_ADE20K/val \
--ten-crop --crop-range 1.1 --image-size 512 --split validation
```
* COCOStuff
```bash
bash scripts/autoregressive/extract_file_cocostuff.sh \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--data-path data/Captioned_COCOStuff/data --code-path data/Captioned_COCOStuff/val \
--ten-crop --crop-range 1.1 --image-size 512 --split validation
```
* MultiGen
```bash
bash scripts/autoregressive/extract_file_multigen.sh \
--vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--data-path data/MultiGen20M/data --code-path data/MultiGen20M/val \
--ten-crop --crop-range 1.1 --image-size 512 --split validation
```

### Testing and Evaluation

#### 1. Class-to-image generation on ImageNet

```bash
bash scripts/autoregressive/test_c2i.sh \
--vq-ckpt ./checkpoints/vq/vq_ds16_c2i.pt \
--gpt-ckpt ./checkpoints/c2i/canny/LlamaGen-L.pt \
--code-path /path/imagenet/val/imagenet_code_c2i_flip_ten_crop \
--gpt-model GPT-L --condition-type canny --get-condition-img True \
--sample-dir ./sample --save-image True
```

```bash
python create_npz.py --generated-images ./sample/imagenet/canny
```
Then download imagenet [validation data](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) which contains 10000 images, or you can use the whole validation data as reference data by running [val.sh](scripts/tokenizer/val.sh). 

Calculate the FID score:
```bash
python evaluations/c2i/evaluator.py /path/imagenet/val/FID/VIRTUAL_imagenet256_labeled.npz \
sample/imagenet/canny.npz
```

#### 2. Text-to-image generation on ADE20k

Download Mask2Former([weight](https://download.openmmlab.com/mmsegmentation/v0.5/mask2former/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640/mask2former_swin-l-in22k-384x384-pre_8xb2-160k_ade20k-640x640_20221203_235933-7120c214.pth)) and save it to `evaluations/`.  

Use this command to get 2000 images based on the segmentation mask:

```bash
bash scripts/autoregressive/test_t2i.sh --vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/seg/seg_ade20k.pt \
--code-path data/Captioned_ADE20K/val --gpt-model GPT-XL --image-size 512 \
--sample-dir sample/ade20k --condition-type seg --seed 0
```
Calculate mIoU of the segmentation masks from the generated images:
```sh
python evaluations/ade20k_mIoU.py
```

#### 3. Text-to-image generation on COCOStuff

Download DeepLabV3([weight](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3/deeplabv3_r101-d8_512x512_4x4_320k_coco-stuff164k/deeplabv3_r101-d8_512x512_4x4_320k_coco-stuff164k_20210709_155402-3cbca14d.pth)) and save it to `evaluations/`.

Generate images using segmentation masks as condition controls:
```bash
bash scripts/autoregressive/test_t2i.sh --vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/seg/seg_cocostuff.pt \
--code-path data/Captioned_COCOStuff/val --gpt-model GPT-XL --image-size 512 \
--sample-dir sample/cocostuff --condition-type seg --seed 0
```
Calculate mIoU of the segmentation masks from the generated images:
```bash
python evaluations/cocostuff_mIoU.py
```

#### 4. Text-to-image generation on MultiGen-20M

We adopt **generation with HED edges** as the example:

Generate 5000 images based on the HED edges generated from validation images
```sh
bash scripts/autoregressive/test_t2i.sh --vq-ckpt checkpoints/vq/vq_ds16_t2i.pt \
--gpt-ckpt checkpoints/t2i/hed/hed.safetensors --code-path data/MultiGen20M/val \
--gpt-model GPT-XL --image-size 512 --sample-dir sample/multigen/hed \
--condition-type hed --seed 0
```

Evaluate the conditional consistency (SSIM):
```bash
python evaluations/hed_ssim.py
```
Calculate the FID score:
```bash
python evaluations/clean_fid.py --val-images data/MultiGen20M/val/image --generated-images sample/multigen/hed/visualization
```

### Training ControlAR

#### 1. Class-to-image (Canny)

```bash
bash scripts/autoregressive/train_c2i_canny.sh --cloud-save-path output \
--code-path data/imagenet/train/imagenet_code_c2i_flip_ten_crop \
--image-size 256 --gpt-model GPT-B --gpt-ckpt checkpoints/llamagen/c2i_B_256.pt
```

#### 2. Text-to-image (Canny)

```bash
bash scripts/autoregressive/train_t2i_canny.sh 
```


## Acknowledgments

The development of ControlAR is based on [LlamaGen](https://github.com/FoundationVision/LlamaGen), [ControlNet](https://github.com/lllyasviel/ControlNet), [ControlNet++](https://github.com/liming-ai/ControlNet_Plus_Plus), and [AiM](https://github.com/hp-l33/AiM), and we sincerely thank the contributors for those great works!

## Citation
If you find ControlAR is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.

```bibtex
@inproceedings{ControlAR,
      title={ControlAR: Controllable Image Generation with Autoregressive Models}, 
      author={Li, Zongming and Cheng, Tianheng and Chen, Shoufa and Sun, Peize and Shen, Haocheng and Ran, Longjin and Chen, Xiaoxin and Liu, Wenyu and Wang, Xinggang},
      booktitle={International Conference on Learning Representations},
      year={2025}
}
```



================================================
FILE: autoregressive/models/README.md
================================================
Download the ViT weights first

ViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \
Dinov2-small: https://huggingface.co/facebook/dinov2-small \
Dinov2-base: https://huggingface.co/facebook/dinov2-base

Put them here


================================================
FILE: autoregressive/models/dinov2_adapter.py
================================================
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import requests
import torch
import torch.nn as nn


class Dinov2_Adapter(nn.Module):
    """DINOv2-backed feature extractor used as the control-condition adapter.

    Resizes inputs from the 16-pixel patch grid to the 14-pixel patch grid
    DINOv2 expects, then returns the patch-token features (CLS token dropped).
    """
    def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'):
        super(Dinov2_Adapter, self).__init__()
        print(f"Choose adapter size: {adapter_size}")
        print(f"condition type: {condition_type}")
        # Loads local weights, e.g. autoregressive/models/dinov2-small.
        self.model = AutoModel.from_pretrained(f'autoregressive/models/dinov2-{adapter_size}')
        self.condition_type = condition_type

    def to_patch14(self, input):
        # Keep the number of patches fixed: each 16-px patch becomes 14 px.
        height, width = input.shape[2:]
        target = ((height // 16) * 14, (width // 16) * 14)
        if self.condition_type in ['canny', 'seg']:
            # Hard-edged / label-map conditions: nearest keeps values discrete.
            return torch.nn.functional.interpolate(input, size=target, mode='nearest')
        # Smooth conditions (depth, lineart, hed): bicubic for quality.
        return torch.nn.functional.interpolate(input, size=target, mode='bicubic', align_corners=True)

    def forward(self, x):
        x = self.to_patch14(x)
        x = self.model(x)
        return x.last_hidden_state[:, 1:]  # drop CLS token, keep patch tokens


if __name__ == '__main__':
    # Smoke test: run a random batch through the adapter on GPU and print the
    # output feature shape (requires CUDA and the local DINOv2 weights).
    model = Dinov2_Adapter().cuda()
    inputs = torch.randn(4,3,512,512).cuda()
    outputs = model(inputs)
    print(outputs.shape)

================================================
FILE: autoregressive/models/generate.py
================================================
# Modified from:
#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
#   DiT:      https://github.com/facebookresearch/DiT/blob/main/models.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._dynamo.config
import torch._inductor.config
import copy
import time
# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future


### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
def top_k_top_p_filtering(
    logits,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
):
    """Mask logits in place using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: (batch, vocab) logits distribution; modified in place and
            also returned.
        top_k: if > 0, keep only the k highest-probability tokens.
        top_p: if < 1.0, keep the smallest set of tokens whose cumulative
            probability reaches top_p (Holtzman et al., arXiv:1904.09751).
        filter_value: value written into masked-out positions.
        min_tokens_to_keep: lower bound on surviving tokens per example.
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # safety clamp
        # Everything strictly below the k-th best logit is filtered.
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Tokens past the nucleus threshold are dropped (ties at 0 kept).
        drop = cum_probs > top_p
        if min_tokens_to_keep > 1:
            drop[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold survives.
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = 0

        # Map the sorted-order mask back to vocabulary order.
        logits[drop.scatter(1, sorted_indices, drop)] = filter_value
    return logits


def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sample_logits=True):
    """Sample the next token from the final position of `logits`.

    Args:
        logits: (batch, seq, vocab) raw model outputs; only the last
            sequence position is used.
        temperature: softmax temperature, clamped to >= 1e-5 to avoid
            division by zero.
        top_k / top_p: forwarded to `top_k_top_p_filtering` when active.
        sample_logits: multinomial sampling when True, greedy top-1 when False.

    Returns:
        (idx, probs): sampled token ids of shape (batch, 1) and the filtered
        probability distribution of shape (batch, vocab).
    """
    logits = logits[:, -1, :] / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=-1)
    if sample_logits:
        idx = torch.multinomial(probs, num_samples=1)
    else:
        _, idx = torch.topk(probs, k=1, dim=-1)
    return idx, probs


def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = 0, **kwargs):
    """Convert raw logits to a probability distribution.

    Applies temperature scaling (clamped to >= 1e-5), optional top-k/top-p
    filtering, then softmax over the last dimension.

    Fix: the previous default ``top_k=None`` made the bare call raise a
    TypeError on ``None > 0``; 0 (= disabled) preserves the intended behavior.
    """
    logits = logits / max(temperature, 1e-5)
    if top_k > 0 or top_p < 1.0:
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, control_strength: float=1, **sampling_kwargs):
    """Run the prompt through the model once and sample the first new token."""
    if cfg_scale > 1.0:
        # Batch is laid out as [cond | uncond]; blend the halves with CFG.
        raw_logits, _ = model(None, cond_idx, input_pos, condition=condition, control_strength=control_strength)
        cond_logits, uncond_logits = torch.split(raw_logits, len(raw_logits) // 2, dim=0)
        guided = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
    else:
        guided, _ = model(None, cond_idx, input_pos, condition=condition)

    return sample(guided, **sampling_kwargs)[0]


def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor,  **sampling_kwargs):
    """Decode exactly one token at `input_pos` using the KV cache."""
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        # Duplicate the token for the [cond | uncond] halves of the batch.
        doubled = torch.cat([x, x])
        raw_logits, _ = model(doubled, cond_idx=None, input_pos=input_pos, condition=condition)
        cond_logits, uncond_logits = torch.split(raw_logits, len(raw_logits) // 2, dim=0)
        if cfg_flag:
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        # NOTE(review): condition is not forwarded on this branch; the model
        # appears to cache it during prefill — confirm against Transformer.forward.
        logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, 
    cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
    **sampling_kwargs):
    """Autoregressively decode `num_new_tokens` tokens with the KV cache.

    Returns the lists of sampled tokens and their probability distributions.
    """
    new_tokens, new_probs = [], []
    cfg_enabled = True
    for step in range(num_new_tokens):
        # Force math-only SDPA: actually better for Inductor to codegen attention here.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            if cfg_interval > -1 and step > cfg_interval:
                cfg_enabled = False  # stop guidance after the configured interval
            tok, prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_enabled, condition=condition, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(tok.clone())
            new_probs.append(prob.clone())
            cur_token = tok.view(-1, 1)

    return new_tokens, new_probs


@torch.no_grad()
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, control_strength=1, **sampling_kwargs):
    """Autoregressively generate `max_new_tokens` image tokens.

    `cond` is class ids for c2i models or caption embeddings for t2i models;
    `condition` is the optional control image, encoded once via the model's
    adapter. Returns generated token ids of shape (batch, max_new_tokens).
    """
    # Encode the control image once; reused for every decoding step.
    if condition is not None:
        condition = model.adapter(condition)
        condition = model.adapter_mlp(condition)
    if model.model_type == 'c2i':
        if cfg_scale > 1.0:
            # Null class id (== num_classes) for the unconditional CFG half.
            cond_null = torch.ones_like(cond) * model.num_classes
            cond_combined = torch.cat([cond, cond_null])
            if condition is not None:
                condition_null = torch.zeros_like(condition)
                condition_combined = torch.cat((condition, condition_null), dim=0)
            else:
                condition_combined = None
        else:
            cond_combined = cond
            if condition is not None:
                condition_combined = condition
            else:
                condition_combined = None
        T = 1+condition_token_nums
    elif model.model_type == 't2i':
        if cfg_scale > 1.0:
            # Unconditional caption embedding for the CFG half.
            cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
            cond_combined = torch.cat([cond, cond_null])
            
            if condition is not None:
                condition_null = torch.zeros_like(condition)
                condition_combined = torch.cat((condition, condition_null), dim=0)
            else:
                condition_combined = None
        else:
            cond_combined = cond
            if condition is not None:
                condition_combined = condition
            else:
                condition_combined = None
        T = cond.shape[1]      
    else:
        raise Exception("please check model type")

    T_new = T + max_new_tokens
    max_seq_length = T_new
    max_batch_size = cond.shape[0]

    device = cond.device
    with torch.device(device):
        # CFG doubles the effective batch ([cond | uncond] halves).
        max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
        model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
    
    if emb_masks is not None:
        # Zero out padded prompt positions in the causal mask.
        assert emb_masks.shape[0] == max_batch_size
        assert emb_masks.shape[-1] == T
        if cfg_scale > 1.0:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
        else:
            model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)

        # Restore the diagonal so every position can attend to itself.
        eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
        model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
    
    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
    input_pos = torch.arange(0, T, device=device)
    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength, **sampling_kwargs)
    seq[:, T:T+1] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)
    generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
    seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
    return seq[:, T:]


================================================
FILE: autoregressive/models/gpt.py
================================================
# Modified from:
#   VQGAN:    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
#   DiT:      https://github.com/facebookresearch/DiT/blob/main/models.py  
#   nanoGPT:  https://github.com/karpathy/nanoGPT/blob/master/model.py
#   llama:    https://github.com/facebookresearch/llama/blob/main/llama/model.py
#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
#   PixArt:   https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
from dataclasses import dataclass
from typing import Optional, List

import io
import torch
import torch.nn as nn
from torch.nn import functional as F
from utils.drop_path import DropPath
from autoregressive.models.dinov2_adapter import Dinov2_Adapter
from autoregressive.models.vit_adapter import ViT_Adapter

def get_causal_mask(seq_length):
    """Return an additive causal attention mask of shape (seq_length, seq_length).

    Entries strictly above the diagonal are -inf (future positions blocked);
    all other entries are 0.0.

    Fix: the previous version called masked_fill on a *bool* tensor, which
    cannot represent -inf, so the intended additive float mask was never
    produced. Build a float tensor first, then fill.
    """
    future = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool)
    mask = torch.zeros(seq_length, seq_length)
    mask = mask.masked_fill(future, float('-inf'))
    return mask

def find_multiple(n: int, k: int):
    """Round `n` up to the nearest multiple of `k`."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)

@dataclass
class ModelArgs:
    """Hyperparameters for the ControlAR GPT transformer."""
    # Transformer dimensions
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None  # grouped-query attention; None -> same as n_head
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000  # rotary embedding base frequency
    norm_eps: float = 1e-5  # RMSNorm epsilon
    initializer_range: float = 0.02  # std for linear/embedding init
    
    # Dropout / regularization
    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    # Conditioning (class labels for c2i, caption embeddings for t2i)
    num_classes: int = 1000
    caption_dim: int = 2048
    class_dropout_prob: float = 0.1  # classifier-free-guidance dropout prob
    model_type: str = 'c2i'  # 'c2i' or 't2i'

    # Tokenizer / sequence sizes
    vocab_size: int = 16384
    cls_token_num: int = 1
    block_size: int = 256  # number of image tokens (grid_size ** 2)
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Control-condition tokens
    condition_token_num: int = 256
    image_size: int = 256


#################################################################################
#                      Embedding Layers for Class Labels                        #
#################################################################################
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra row for the "null" (dropped) label when CFG is on.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """Replace a random (or forced) subset of labels with the null class id."""
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_ids, self.num_classes, labels), drop_ids

    def forward(self, labels, train, force_drop_ids=None):
        apply_drop = (train and self.dropout_prob > 0) or (force_drop_ids is not None)
        drop_ids = None
        if apply_drop:
            labels, drop_ids = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels).unsqueeze(1)
        # When dropout was applied, callers also need the drop mask.
        return (embeddings, drop_ids) if apply_drop else embeddings


class ConditionEmbedder(nn.Module):
    """
    Embeds Condition into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
        super().__init__()
        self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
        # All-zero unconditional embedding (the scaling keeps it zero; kept
        # for parity with CaptionEmbedder's initialisation).
        self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
        """Zero out a random, forced, or externally supplied subset of conditions."""
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        elif drop_ids is None:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        blank = torch.zeros_like(caption[0])
        return torch.where(drop_ids[:, None, None], blank, caption)

    def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
        if (train and self.uncond_prob > 0) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids, drop_ids)
        return self.cap_proj(caption)

#################################################################################
#                      Embedding Layers for Text Feature                        #
#################################################################################
class CaptionEmbedder(nn.Module):
    """
    Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
        super().__init__()
        self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
        # Random unconditional embedding, scaled by 1/sqrt(in_channels).
        self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """Swap a random (or forced) subset of captions for the uncond embedding."""
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        return torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)

    def forward(self, caption, train, force_drop_ids=None):
        if (train and self.uncond_prob > 0) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        return self.cap_proj(caption)


class MLP(nn.Module):
    """Two-layer GELU MLP with zero-initialized weights.

    Both linear layers start at zero, so the module initially contributes
    nothing when used as a residual / adapter branch.
    """
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

        # Deliberate zero init: output is exactly zero at start of training.
        nn.init.zeros_(self.fc1.weight)
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


#################################################################################
#                                  GPT Model                                    #
#################################################################################
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer norm (no mean subtraction), as used in LLaMA."""
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the inverse RMS over the last dimension.
        inv_rms = torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for stability, then cast back to input dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight


class FeedForward(nn.Module):
    """SwiGLU feed-forward block (LLaMA-style)."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        # 2/3 of the classic 4*dim expansion, rounded up to multiple_of.
        hidden_dim = int(2 * (4 * config.dim) / 3)
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.ffn_dropout(self.w2(gated))


class KVCache(nn.Module):
    """Pre-allocated key/value cache for incremental (KV-cached) decoding."""
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_head, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write k/v at positions `input_pos` and return the full cache tensors.

        input_pos: [S]; k_val / v_val: [B, H, S, D].
        """
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, optional grouped
    KV heads (n_kv_head < n_head), and an optional KV cache for decoding."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        # Grouped-query attention: fall back to n_head when n_kv_head unset.
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None  # installed by Transformer.setup_caches for inference

        # regularization
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor = None, 
        input_pos: Optional[torch.Tensor] = None, 
        mask: Optional[torch.Tensor] = None
    ):
        """Attend over x; writes into the KV cache at `input_pos` when present."""
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim
        # Single fused projection split into query / key / value.
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        
        # Rotary position embedding on queries and keys only.
        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv
        # Expand KV heads to match the query head count (GQA).
        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values, 
            attn_mask=mask, 
            is_causal=True if mask is None else False, # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)            
        
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention + SwiGLU FFN with residual
    connections and optional stochastic depth (drop-path)."""
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        attn_out = self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
        h = x + self.drop_path(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + self.drop_path(ffn_out)


class Transformer(nn.Module):
    """GPT-style autoregressive transformer with control-condition injection.

    An adapter (ViT/DINOv2) encodes the control image; the resulting tokens
    are projected and *added* to the hidden states at three evenly spaced
    depths (every `layer_internal` layers) rather than concatenated.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.condition_token_num = config.condition_token_num
        # Condition features are injected every n_layer // 3 layers (3 points).
        self.layer_internal = config.n_layer // 3
        # self.adapter = Adapter(output_dim=config.dim)
        self.adapter = ViT_Adapter()
        # self.adapter = Deit_Adapter()
        # self.adapter = EVA_Adapter(img_size=256, in_chans=3, embed_dim=384)
        # self.adapter = Dinov2_Adapter(adapter_size='base')
        # self.adapter = EVA_Adapter()
        # 384 matches the ViT-small adapter output width.
        self.adapter_mlp = MLP(384, config.dim, config.dim)
        # self.adapter_mlp = MLP(768, config.dim, config.dim)
        # self.cross_attention = nn.MultiheadAttention(embed_dim=config.dim, num_heads=8,batch_first=True)
        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 't2i':
            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
        else:
            raise Exception("please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.condition_mlp = ConditionEmbedder((config.image_size // 16)**2, config.dim, config.class_dropout_prob, (config.image_size // 16)**2, config.vocab_size)

        # One zero-initialized projection per injection point (3 total).
        self.condition_layers = torch.nn.ModuleList()
        for layer_id in range(3):
            self.condition_layers.append(MLP(config.dim,config.dim,config.dim))

        # transformer blocks, with a linearly increasing drop-path rate
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.condition_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 2d rotary pos embedding
        grid_size = int(self.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num+self.condition_token_num)
        
        # KVCache (allocated lazily via setup_caches)
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()
        self.condition_token = None
        self.global_token = None
        self.mask = get_causal_mask(256)

    def initialize_weights(self):        
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)


    def _init_weights(self, module):
        # Normal init for linear/embedding weights, zero for biases.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """Allocate per-layer KV caches, the causal mask, and the rotary table."""
        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        #     return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)  # 
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        grid_size = int(self.config.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num+self.condition_token_num)


    
    def forward(
        self, 
        idx: torch.Tensor, 
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos:  Optional[torch.Tensor] = None, 
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None
    ):
        """Three modes: training (idx & cond_idx set), prefill (cond_idx only),
        and single-token decode with KV cache (idx only)."""
        if idx is not None and cond_idx is not None: # training or naive inference
            cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
            cond_embeddings = cond_embeddings[:,:self.cls_token_num]
            token_embeddings = self.tok_embeddings(idx)
            if condition is not None:
                condition_embeddings = self.adapter(condition)
                condition_embeddings = self.adapter_mlp(condition_embeddings)
                # Drop condition tokens in sync with dropped class/caption tokens (CFG).
                self.condition_token = self.condition_mlp(condition_embeddings,train=self.training, drop_ids=drop_ids)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis.to(h.device)
        else:
            if cond_idx is not None: # prefill in inference
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
                token_embeddings = token_embeddings[:,:self.cls_token_num]
                if condition is not None:
                    # NOTE(review): `condition` is already adapter-encoded by generate();
                    # the bfloat16 cast assumes the model runs in bf16 — confirm.
                    condition_embeddings = self.condition_mlp(condition.to(torch.bfloat16),train=self.training)
                    self.condition_token = condition_embeddings
            else: # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)
            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis  # NOTE(review): no-op; presumably a leftover device move
        if self.training:
            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        # transformer blocks
        for i, layer in enumerate(self.layers):
            # Add projected condition tokens every `layer_internal` layers.
            if i%self.layer_internal == 0:
                if self.training:
                    h = h + self.condition_layers[i//self.layer_internal](self.condition_token)
                else:
                    if len(input_pos)>1:
                        # Prefill: only the last prompt position receives the first condition token.
                        h[:,-1:] = h[:,-1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
                    else:
                        h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos])
            h = layer(h, freqs_cis, input_pos, mask)
        # output layers
        h = self.norm(h)
        logits = self.output(h).float()
        
        if self.training:
            # NOTE(review): slice offset includes condition_token_num even though
            # condition tokens are added rather than concatenated — confirm intended.
            logits = logits[:, self.cls_token_num+self.condition_token_num - 1:].contiguous()
        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            # Masked loss: average only over samples flagged valid.
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss


    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        # FSDP wraps each transformer block as a separate unit.
        return list(self.layers)



#################################################################################
#                      Rotary Positional Embedding Functions                    #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """Precompute the 1D rotary-embedding table as (cos, sin) pairs.

    Returns a (cls_token_num + seq_len, n_elem // 2, 2) tensor whose first
    cls_token_num rows are zeros (placeholder entries for the conditioning
    tokens), followed by one row of rotation pairs per sequence position.
    """
    half = n_elem // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem))
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)  # (seq_len, half)
    unit = torch.polar(torch.ones_like(angles), angles)
    table = torch.stack([unit.real, unit.imag], dim=-1)  # (seq_len, half, 2)
    return torch.cat([torch.zeros(cls_token_num, half, 2), table])


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """Precompute an axial 2D rotary-embedding table for a square token grid.

    Half of the head dimension is rotated by the row index and the other half
    by the column index. Returns (cls_token_num + grid_size**2, n_elem // 2, 2)
    with all-zero leading rows for the conditioning tokens.
    """
    half_dim = n_elem // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    positions = torch.arange(grid_size, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)  # (grid_size, half_dim // 2)
    # First half of the pairs follows the row position, second half the column.
    grid_angles = torch.concat([
        angles[:, None, :].expand(-1, grid_size, -1),
        angles[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)  # (grid_size, grid_size, half_dim)
    table = torch.stack([torch.cos(grid_angles), torch.sin(grid_angles)], dim=-1).flatten(0, 1)
    return torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), table])


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate the head dimension of x by the (cos, sin) pairs in freqs_cis.

    x: (bs, seq_len, n_head, head_dim); freqs_cis: (seq_len, head_dim // 2, 2).
    Computation runs in float32 and the result is cast back to x's dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    cos, sin = table[..., 0], table[..., 1]
    re, im = pairs[..., 0], pairs[..., 1]
    # Complex multiplication (re + i*im) * (cos + i*sin), written out in reals.
    rotated = torch.stack([re * cos - im * sin, im * cos + re * sin], dim=-1)
    return rotated.flatten(3).type_as(x)



#################################################################################
#                                GPT Configs                                    #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    """~6.6B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B

def GPT_3B(**kwargs):
    """~3.1B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B

def GPT_1B(**kwargs):
    """~1.2B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    """~3.9B-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B

def GPT_XXL(**kwargs):
    """~1.4B-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B

def GPT_XL(**kwargs):
    """~775M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M

def GPT_L(**kwargs):
    """~343M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M

def GPT_B(**kwargs):
    """~111M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M


# Registry mapping model-size names (CLI choices) to their factory functions.
GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B, 
}

================================================
FILE: autoregressive/models/gpt_t2i.py
================================================
# Modified from:
#   VQGAN:    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
#   DiT:      https://github.com/facebookresearch/DiT/blob/main/models.py  
#   nanoGPT:  https://github.com/karpathy/nanoGPT/blob/master/model.py
#   llama:    https://github.com/facebookresearch/llama/blob/main/llama/model.py
#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
#   PixArt:   https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
from dataclasses import dataclass
from typing import Optional, List


import torch
import torch.nn as nn
from torch.nn import functional as F
from utils.drop_path import DropPath
from autoregressive.models.vit_adapter import ViT_Adapter
from autoregressive.models.dinov2_adapter import Dinov2_Adapter


def get_causal_mask(seq_length):
    """Build an additive causal attention mask.

    Returns a (seq_length, seq_length) float tensor with 0.0 on and below the
    diagonal and -inf strictly above it, so that adding it to attention scores
    (or passing it as a float attn_mask) forbids attending to future positions.

    The previous implementation chained masked_fill on a *bool* tensor: the
    -inf / 0.0 fill values were cast back to bool, so the intended float mask
    was never produced.
    """
    mask = torch.full((seq_length, seq_length), float('-inf'))
    mask = torch.triu(mask, diagonal=1)
    return mask

def find_multiple(n: int, k: int):
    """Round n up to the nearest multiple of k (n itself if already aligned)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)

@dataclass
class ModelArgs:
    """Hyperparameters for the conditioned GPT Transformer defined below."""
    # core transformer shape
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None  # grouped-query attention; None -> same as n_head
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02
    
    # dropout / stochastic-depth rates
    token_dropout_p: float = 0.1
    attn_dropout_p: float = 0.0
    resid_dropout_p: float = 0.1
    ffn_dropout_p: float = 0.1
    drop_path_rate: float = 0.0

    # conditioning setup
    num_classes: int = 1000
    caption_dim: int = 2048  # input feature dim for the caption embedder (t2i)
    class_dropout_prob: float = 0.1  # classifier-free-guidance drop probability
    model_type: str = 'c2i'  # 'c2i' (class-to-image) or 't2i' (text-to-image)

    # tokenizer / sequence geometry
    vocab_size: int = 16384
    cls_token_num: int = 1
    block_size: int = 256  # number of image tokens (must be a perfect square)
    max_batch_size: int = 32
    max_seq_len: int = 2048
    adapter_size: str = 'small'  # 'small' (384-dim features) or 'base' (768-dim)
    condition_type: str = 'canny'



#################################################################################
#                      Embedding Layers for Class Labels                        #
#################################################################################
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # Reserve one extra row (index num_classes) as the unconditional embedding.
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        # Dropped samples are mapped to the reserved unconditional class index.
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels, drop_ids

    def forward(self, labels, train, force_drop_ids=None):
        """Embed labels to shape (B, 1, hidden_size).

        NOTE(review): return arity differs by path — a bare tensor in eval mode,
        but an (embeddings, drop_ids) tuple during training with dropout or when
        force_drop_ids is given. Callers must match the path they invoke.
        """
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels,drop_ids = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels).unsqueeze(1)
        if (train and use_dropout) or (force_drop_ids is not None):
            return embeddings,drop_ids
        else:
            return embeddings


class ConditionEmbedder(nn.Module):
    """
    Projects condition-token embeddings through an MLP, randomly replacing whole
    sequences with a fixed zero "unconditional" embedding so the model supports
    classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
        super().__init__()
        # NOTE(review): in_channels and vocab_size are accepted but unused; the
        # projection maps hidden_size -> hidden_size.
        self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
        self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
        """Replace dropped samples' condition tokens with the unconditional embedding."""
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        elif drop_ids is None:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        # Broadcast the per-sample drop decision over all tokens and channels.
        return torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption)

    def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
        """Apply CFG dropout (training / forced only) and project the tokens."""
        if (train and self.uncond_prob > 0) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids, drop_ids)
        return self.cap_proj(caption)

#################################################################################
#                      Embedding Layers for Text Feature                        #
#################################################################################
class CaptionEmbedder(nn.Module):
    """
    Projects caption features to the model width, randomly swapping whole captions
    for a learned unconditional embedding to enable classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
        super().__init__()
        self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
        self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """Swap dropped captions for the unconditional embedding; return (caption, drop mask)."""
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
        dropped = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
        return dropped, drop_ids

    def forward(self, caption, train, force_drop_ids=None):
        """Project captions; in train/forced mode also return the drop mask.

        NOTE(review): return arity differs by path — a bare tensor in eval mode,
        an (embeddings, drop_ids) tuple otherwise.
        """
        apply_drop = (train and self.uncond_prob > 0) or (force_drop_ids is not None)
        if apply_drop:
            caption, drop_ids = self.token_drop(caption, force_drop_ids)
        embeddings = self.cap_proj(caption)
        return (embeddings, drop_ids) if apply_drop else embeddings


class MLP(nn.Module):
    """Two-layer GELU MLP with bias-free linear layers, zero-initialized.

    With both weight matrices zeroed the module outputs zeros at construction;
    downstream users are expected to re-initialize or load trained weights
    (Transformer._init_weights re-initializes every nn.Linear it visits).
    """
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
        nn.init.zeros_(self.fc1.weight)
        nn.init.zeros_(self.fc2.weight)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


#################################################################################
#                                  GPT Model                                    #
#################################################################################
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class FeedForward(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)), dropout on the output."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        # LLaMA sizing: 2/3 of 4*dim, optionally rescaled, rounded up to multiple_of.
        hidden_dim = int(2 * (4 * config.dim) / 3)
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w3(x)
        return self.ffn_dropout(self.w2(gated))


class KVCache(nn.Module):
    """Preallocated key/value cache for incremental decoding.

    update() scatters the new entries at the given positions and returns the
    full cache buffers (not just the written slice).
    """
    def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
        super().__init__()
        shape = (max_batch_size, n_head, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S]; k_val / v_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        return self.k_cache, self.v_cache


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embedding, grouped-query
    KV heads, and an optional KV cache for incremental decoding."""
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        # Grouped-query attention: n_kv_head may be smaller than n_head.
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None  # installed externally (see Transformer.setup_caches)

        # regularization
        self.attn_dropout_p = config.attn_dropout_p
        self.resid_dropout = nn.Dropout(config.resid_dropout_p)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor = None, 
        input_pos: Optional[torch.Tensor] = None, 
        mask: Optional[torch.Tensor] = None
    ):
        """Attend over x of shape (bsz, seqlen, dim); input_pos indexes the KV
        cache when one is installed. With no mask, SDPA runs in causal mode."""
        bsz, seqlen, _ = x.shape
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
        
        # Rotate queries and keys by their positional phases.
        xq = apply_rotary_emb(xq, freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis)

        xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))

        if self.kv_cache is not None:
            keys, values = self.kv_cache.update(input_pos, xk, xv)
        else:
            keys, values = xk, xv
        # Expand KV heads to match the query head count (GQA).
        keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
        values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)

        output = F.scaled_dot_product_attention(
            xq, keys, values, 
            attn_mask=mask, 
            is_causal=True if mask is None else False, # is_causal=False is for KV cache
            dropout_p=self.attn_dropout_p if self.training else 0)            
        
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        output = self.resid_dropout(self.wo(output))
        return output


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU FFN, each wrapped in a
    residual connection with optional stochastic depth (DropPath)."""
    def __init__(self, config: ModelArgs, drop_path: float):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
        attn_out = self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)
        hidden = x + self.drop_path(attn_out)
        ffn_out = self.feed_forward(self.ffn_norm(hidden))
        return hidden + self.drop_path(ffn_out)


class Transformer(nn.Module):
    """GPT-style autoregressive transformer with control-condition injection.

    A DINOv2-based adapter maps the control input (e.g. canny/depth image) to
    condition tokens; those tokens are added to the hidden states in front of
    every `layer_internal`-th block (three injection points in total), scaled
    by `control_strength` during inference.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        # Conditions are injected before layers 0, n_layer//3 and 2*n_layer//3.
        self.layer_internal = config.n_layer // 3
        # self.adapter = Adapter(output_dim=768)
        # self.adapter = ViT_Adapter()
        # self.adapter = DeiT_Adapter()
        self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type)
        # self.adapter = EVA_Adapter()
        # Project adapter features (384-d for 'small', 768-d for 'base') to model width.
        if config.adapter_size == "small":
            self.adapter_mlp = MLP(384, config.dim, config.dim)
        elif config.adapter_size == 'base':
            self.adapter_mlp = MLP(768, config.dim, config.dim)

        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        elif self.model_type == 't2i':
            self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
        else:
            raise Exception("please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.tok_dropout = nn.Dropout(config.token_dropout_p)

        self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size)
        # One projection per injection point.
        self.condition_layers = torch.nn.ModuleList()
        for layer_id in range(3):
            self.condition_layers.append(MLP(config.dim,config.dim,config.dim))

        # transformer blocks
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config, dpr[layer_id]))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 2d rotary pos embedding
        grid_size = int(self.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
        
        # KVCache
        self.max_batch_size = -1
        self.max_seq_length = -1

        self.initialize_weights()
        self.condition_token = None  # set in forward() when a condition is supplied
        self.mask = get_causal_mask(256)
        self.global_token = None

        self.control_strength = 1    

    def initialize_weights(self):        
        """Normal-init all Linear/Embedding weights, then zero the output head."""
        # Initialize nn.Linear and nn.Embedding
        self.apply(self._init_weights)

        # Zero-out output layers:
        nn.init.constant_(self.output.weight, 0)

        
        
    def _init_weights(self, module):
        # Applied to every submodule; note this also overrides the zero init
        # performed in MLP.__init__.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

        
    def setup_caches(self, max_batch_size, max_seq_length, dtype):
        """Allocate per-layer KV caches plus the boolean causal mask used for
        incremental decoding, and rebuild the rotary table."""
        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
        #     return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)  # 
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)

        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
        grid_size = int(self.config.block_size ** 0.5)
        assert grid_size * grid_size == self.block_size
        self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)


    
    def forward(
        self, 
        idx: torch.Tensor, 
        cond_idx: torch.Tensor,  # cond_idx_or_embed
        input_pos:  Optional[torch.Tensor] = None, 
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        valid: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
        control_strength: Optional[int] = 1
    ):
        """Run one of three modes, selected by which arguments are present:
        training (idx + cond_idx), prefill (cond_idx only, with KV cache), or
        single-token decode (idx only). Returns (logits, loss)."""
        if idx is not None and cond_idx is not None: # training or naive inference
            # NOTE(review): cls_embedding returns (embeddings, drop_ids) only in
            # train/forced-drop mode (see LabelEmbedder/CaptionEmbedder); this
            # unpack assumes the training path — confirm for naive inference.
            cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
            cond_embeddings = cond_embeddings[:,:self.cls_token_num]
            token_embeddings = self.tok_embeddings(idx)
            if condition is not None:
                # Control input -> adapter features -> model-width condition tokens,
                # sharing the CFG drop mask with the class/caption embedding.
                condition_embeddings = self.adapter(condition)
                condition_embeddings = self.adapter_mlp(condition_embeddings)
                self.condition_token = self.condition_mlp(condition_embeddings,train=self.training, drop_ids=drop_ids)
            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)

            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis.to(h.device)
        else:
            if cond_idx is not None: # prefill in inference
                self.control_strength = control_strength
                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
                token_embeddings = token_embeddings[:,:self.cls_token_num]
                if condition is not None:
                    condition_embeddings = self.condition_mlp(condition, train=self.training)#.to(torch.bfloat16),train=self.training)
                    self.condition_token = condition_embeddings
                    # Pre-project the condition tokens once per injection point so the
                    # decode loop only needs an indexed lookup.
                    self.condition_token = [self.condition_layers[0](self.condition_token),
                                            self.condition_layers[1](self.condition_token),
                                            self.condition_layers[2](self.condition_token)]
                    
            else: # decode_n_tokens(kv cache) in inference
                token_embeddings = self.tok_embeddings(idx)
            bs = token_embeddings.shape[0]
            mask = self.causal_mask[:bs, None, input_pos]
            h = self.tok_dropout(token_embeddings)
            self.freqs_cis = self.freqs_cis

        if self.training:
            freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
        else:
            freqs_cis = self.freqs_cis[input_pos]
        # transformer blocks
        for i, layer in enumerate(self.layers):
            if i%self.layer_internal == 0:
                # Inject the (strength-scaled) condition tokens ahead of this block.
                # NOTE(review): requires self.condition_token to be set; training
                # without a condition argument would hit None here — confirm callers.
                if self.training:
                    h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
                else:
                    if len(input_pos)>1:
                        # h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
                        h[:,-1:] = h[:, -1:] + self.control_strength*self.condition_token[i//self.layer_internal][:,0:1]
                    else:
                        # h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
                        h = h + self.control_strength*self.condition_token[i//self.layer_internal][:,input_pos-self.cls_token_num+1]
            h = layer(h, freqs_cis, input_pos, mask)
        # output layers
        h = self.norm(h)
        logits = self.output(h).float()
        
        if self.training:
            # Drop logits for the class/caption prefix; keep the shifted image-token logits.
            logits = logits[:, self.cls_token_num - 1:].contiguous()
        # if we are given some desired targets also calculate the loss
        loss = None
        if valid is not None:
            # Masked loss: only samples flagged valid contribute to the average.
            loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
            valid_all = valid[:,None].repeat(1, targets.shape[1]).view(-1)
            loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
        elif targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))


        return logits, loss


    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        """Return the per-layer transformer blocks for FSDP auto-wrapping."""
        return list(self.layers)



#################################################################################
#                      Rotary Positional Embedding Functions                    #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """1D RoPE table: (cls_token_num + seq_len, n_elem // 2, 2) of (cos, sin)
    pairs, with all-zero leading rows for the conditioning tokens."""
    exponents = torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    angles = torch.outer(torch.arange(seq_len, device=inv_freq.device), inv_freq)
    table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
    return torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), table])


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """Axial 2D RoPE table for a grid_size x grid_size token grid.

    Row positions drive the first half of the per-head pairs, column positions
    the second half. Returns (cls_token_num + grid_size**2, n_elem // 2, 2)
    with zero rows prepended for the conditioning tokens.
    """
    half_dim = n_elem // 2
    exponents = torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim
    inv_freq = 1.0 / (base ** exponents)
    angles = torch.outer(torch.arange(grid_size, device=inv_freq.device), inv_freq)
    row_part = angles[:, None, :].expand(-1, grid_size, -1)
    col_part = angles[None, :, :].expand(grid_size, -1, -1)
    grid_angles = torch.concat([row_part, col_part], dim=-1)  # (grid, grid, half_dim)
    table = torch.stack([torch.cos(grid_angles), torch.sin(grid_angles)], dim=-1).flatten(0, 1)
    return torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), table])


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary position embedding to x.

    x: (bs, seq_len, n_head, head_dim); freqs_cis: (seq_len, head_dim // 2, 2).
    Math runs in float32; the result is cast back to x's dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    table = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    re, im = pairs.unbind(-1)
    cos, sin = table.unbind(-1)
    # (re + i*im) * (cos + i*sin), expanded into real arithmetic.
    out = torch.stack([re * cos - im * sin, im * cos + re * sin], dim=-1)
    return out.flatten(3).type_as(x)



#################################################################################
#                                GPT Configs                                    #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    """~6.6B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B

def GPT_3B(**kwargs):
    """~3.1B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B

def GPT_1B(**kwargs):
    """~1.2B-parameter text-conditional configuration."""
    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    """~3.9B-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B

def GPT_XXL(**kwargs):
    """~1.4B-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B

def GPT_XL(**kwargs):
    """~775M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M

def GPT_L(**kwargs):
    """~343M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M

def GPT_B(**kwargs):
    """~111M-parameter class-conditional configuration."""
    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M


# Registry mapping model-size names (CLI choices) to their factory functions.
GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B, 
}


================================================
FILE: autoregressive/models/vit_adapter.py
================================================
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import requests
import torch
import torch.nn as nn


class ViT_Adapter(nn.Module):
    """Wraps a locally stored ViT-small backbone and returns its patch-token features.

    Constructor parameters other than the module itself (input_dim, output_dim,
    attention, pool, nheads, dropout) are currently unused.
    """
    def __init__(self, input_dim=3, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1):
        super(ViT_Adapter, self).__init__()
        # Loads from a local checkpoint directory, not the HuggingFace hub.
        self.model = AutoModel.from_pretrained('autoregressive/models/vit-small')
        
    def forward(self, x):
        # interpolate_pos_encoding=True lets the ViT accept input resolutions
        # other than its pretraining size.
        x = self.model(x,interpolate_pos_encoding=True)
        # Drop the first token (presumably the [CLS] token — confirm with the
        # backbone config), keeping only patch tokens.
        return x.last_hidden_state[:, 1:]


if __name__ == '__main__':
    # Quick smoke test: report the parameter count and the adapter's output
    # shape for a batch of 512x512 RGB inputs. (Removed a leftover
    # `import pdb; pdb.set_trace()` breakpoint that halted the script.)
    model = ViT_Adapter().cuda()
    print(sum(p.numel() for p in model.parameters()))
    inputs = torch.randn(4,3,512,512).cuda()

    outputs = model(inputs)

    print(outputs.shape)

================================================
FILE: autoregressive/sample/sample_c2i.py
================================================
# Modified from:
#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
from torchvision.utils import save_image
import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)

from PIL import Image
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from autoregressive.models.gpt import GPT_models
from autoregressive.models.generate import generate
from functools import partial
import torch.nn.functional as F
import numpy as np
import cv2


def main(args):
    """Sample class-to-image generations conditioned on canny/depth controls.

    Loads the VQ image tokenizer and the control-conditioned GPT from their
    checkpoints, builds a fixed batch of bundled example conditions, runs
    autoregressive sampling, decodes the tokens to pixels, and saves a grid
    (conditions on top, samples below) under sample/example/.
    """
    # Setup PyTorch: deterministic, inference-only
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # create and load the VQ image tokenizer
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # free the CPU copy of the weights
    print("image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_token_num=args.condition_token_nums,
        image_size=args.image_size
    ).to(device=device, dtype=precision)

    # Checkpoints may be stored as .safetensors or as torch dicts produced by
    # different trainers (ddp / deepspeed / lightning-style state_dict).
    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint:  # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        del checkpoint
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    print("gpt model is loaded")

    if args.compile:
        print("compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        )  # requires PyTorch 2.0 (optional)
    else:
        print("no need to compile model in demo")

    condition_null = None
    # Fixed example indices shipped with the repo under condition/example/c2i/.
    if args.condition_type == 'canny':
        sample_list = [650, 2312, 15000, 48850]  # canny
    elif args.condition_type == 'depth':
        sample_list = [101, 4351, 10601, 48901]
    else:
        # Previously this fell through to a NameError on `sample_list`.
        raise ValueError(f"unsupported condition type: {args.condition_type}")

    class_labels = [np.load(f"condition/example/c2i/{args.condition_type}/{i}.npy")[0] for i in sample_list]
    # NOTE(review): assumes the example PNGs are single-channel (grayscale);
    # [None,None,...] adds batch and channel dims — confirm against the assets.
    condition_imgs = [np.array(Image.open((f"condition/example/c2i/{args.condition_type}/{i}.png")))[None,None,...] for i in sample_list]
    condition_imgs = torch.from_numpy(np.concatenate(condition_imgs, axis=0)).to(device).to(torch.float32)/255
    condition_imgs = 2*(condition_imgs-0.5)  # rescale [0,1] -> [-1,1]
    print(condition_imgs.shape)
    c_indices = torch.tensor(class_labels, device=device)
    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
    t1 = time.time()

    index_sample = generate(
        gpt_model, c_indices, latent_size ** 2, condition=condition_imgs.repeat(1,3,1,1).to(precision), condition_null=condition_null, condition_token_nums=args.condition_token_nums,
        cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
        temperature=args.temperature, top_k=args.top_k,
        top_p=args.top_p, sample_logits=True,
        )

    sampling_time = time.time() - t1
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    t2 = time.time()
    samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
    decoder_time = time.time() - t2
    print(f"decoder takes about {decoder_time:.2f} seconds.")

    # Save and display images: conditions on the top row, samples below.
    condition_imgs = condition_imgs.repeat(1,3,1,1)
    samples = torch.cat((condition_imgs[:4], samples[:4]),dim=0)
    os.makedirs("sample/example", exist_ok=True)  # save_image does not create dirs
    save_image(samples, f"sample/example/sample_{args.gpt_type}_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    parser.add_argument("--condition-token-nums", type=int, default=0)
    parser.add_argument("--condition-type", type=str, default='canny', choices=['canny', 'depth'])
    args = parser.parse_args()
    main(args)

================================================
FILE: autoregressive/sample/sample_c2i_ddp.py
================================================
# Modified from:
#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.nn.functional as F
import torch.distributed as dist

from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse

from tokenizer.tokenizer_image.vq_model import VQ_models
from autoregressive.models.gpt import GPT_models
from autoregressive.models.generate import generate


def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.
    """
    images = []
    for idx in tqdm(range(num), desc="Building .npz file from samples"):
        png = Image.open(f"{sample_dir}/{idx:06d}.png")
        images.append(np.asarray(png).astype(np.uint8))
    stacked = np.stack(images)
    # expect N RGB images of a uniform height/width
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path


def main(args):
    """Distributed (DDP) class-conditional sampling for FID evaluation.

    Every rank samples its share of images with random class labels and
    writes them as .png files; rank 0 finally packs all samples into one
    .npz file for downstream FID tooling.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # per-rank seed so each process draws different class labels
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # free the CPU copy of the weights

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)
    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
    # checkpoints may come from different trainers; pick the right sub-dict
    if args.from_fsdp: # fsdp
        model_weight = checkpoint
    elif "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint: # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight, maybe add --from-fsdp to run command")
    # if 'freqs_cis' in model_weight:
    #     model_weight.pop('freqs_cis')
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no model compile") 

    # Create folder to save samples:
    # folder name encodes every sampling hyper-parameter for traceability
    model_string_name = args.gpt_model.replace("/", "-")
    if args.from_fsdp:
        ckpt_string_name = args.gpt_ckpt.split('/')[-2]
    else:
        ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-size-{args.image_size_eval}-{args.vq_model}-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()  # wait until rank 0 has created the output directory

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar  # progress bar only on rank 0
    total = 0
    for _ in pbar:
        # Sample inputs: random class labels for class-conditional generation
        c_indices = torch.randint(0, args.num_classes, (n,), device=device)
        qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]

        index_sample = generate(
            gpt_model, c_indices, latent_size ** 2,
            cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True, 
            )
        
        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
        if args.image_size_eval != args.image_size:
            samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
        # map [-1, 1] floats to uint8 HWC pixels on the CPU
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        
        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            # global index interleaves ranks so file names never collide
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
    parser.add_argument("--from-fsdp", action='store_true')
    parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--compile", action='store_true', default=True)
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
    parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale",  type=float, default=1.5)
    parser.add_argument("--cfg-interval", type=float, default=-1)
    parser.add_argument("--sample-dir", type=str, default="samples")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=5000)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=0,help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    args = parser.parse_args()
    main(args)

================================================
FILE: autoregressive/sample/sample_t2i.py
================================================
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from torchvision.utils import save_image

import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
from autoregressive.models.gpt import GPT_models
from autoregressive.models.gpt_t2i import GPT_models
from autoregressive.models.generate import generate
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from dataset.t2i_control import build_t2i_control_code
from accelerate import Accelerator
from dataset.build import build_dataset
from pathlib import Path
from accelerate.utils import ProjectConfiguration, set_seed
import torch.nn.functional as F
from condition.canny import CannyDetector
from condition.hed import HEDdetector
import numpy as np
from PIL import Image
from condition.lineart import LineArt
import cv2
from transformers import DPTImageProcessor, DPTForDepthEstimation
def main(args):
    """Text-to-image sampling with a spatial control condition.

    Loads the VQ tokenizer, the control-conditioned GPT, and the T5 text
    encoder; preprocesses the chosen control image (seg/canny/hed/lineart/
    depth); then generates one sample per prompt and saves it next to the
    condition image under sample/example/.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # free the CPU copy of the weights
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
        adapter_size=args.adapter_size,
    ).to(device=device, dtype=precision)

    # checkpoints may be .safetensors or torch dicts from several trainers
    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint: # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo") 
    
    # T5 text encoder for prompt embeddings (must exist on disk)
    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device, 
        local_cache=True, 
        cache_dir=args.t5_path, 
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )
    

    # Instantiate the detector that converts an RGB image into the chosen
    # control signal ('seg' needs no detector — the image IS the condition).
    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
    with torch.no_grad():
        
        # Build a 2-copy condition batch — presumably one per prompt in the
        # doubled prompt list below (see `prompts * 2`); verify against generate().
        condition_path = args.condition_path
        if args.condition_type == 'seg':
            condition_img = torch.from_numpy(np.array(Image.open(condition_path)))
            condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
        elif args.condition_type == 'canny':
            condition_img = get_control(np.array(Image.open(condition_path)))
            condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
        elif args.condition_type == 'hed':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device))
            condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
        elif args.condition_type == 'lineart':
            condition_img = get_control(torch.from_numpy(np.array(Image.open(condition_path))).permute(2,0,1).unsqueeze(0).to(device).float())
            condition_img = 1 - condition_img  # invert so lines are dark on light
            condition_img = condition_img.repeat(2,3,1,1) * 255
        elif args.condition_type == 'depth':
            images = Image.open(condition_path)
            inputs = processor(images=images, return_tensors="pt", size=(512,512)).to(device)
            outputs = model(**inputs)
            condition_img = outputs.predicted_depth
            condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
            condition_img = (condition_img * 255 / condition_img.max())
        condition_img = condition_img.to(device)
        condition_img = 2*(condition_img/255 - 0.5)  # rescale [0,255] -> [-1,1]
        prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
        prompts = prompts * 2
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)

        if not args.no_left_padding:
            print(f"processing left-padding...")    
            # a naive way to implement left-padding
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                print(f'  prompt {idx} token len: {valid_num}')
                # move the valid tokens to the end (pad on the left)
                new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)
        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks
        c_indices = new_caption_embs * new_emb_masks[:,:, None]  # zero out padding embeddings
        c_emb_masks = new_emb_masks
        qzshape = [len(c_indices), args.codebook_embed_dim, args.image_H//args.downsample_size, args.image_W//args.downsample_size]
        t1 = time.time()
        index_sample = generate(
            gpt_model, c_indices, (args.image_H//args.downsample_size)*(args.image_W//args.downsample_size),#latent_size ** 2, 
            c_emb_masks, condition=condition_img.to(precision),
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True, 
            control_strength=args.control_strength,
            )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")    
        
        t2 = time.time()
        print(index_sample.shape)
        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # grid: condition image first, generated samples after
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        save_image(samples, f"sample/example/sample_t2i_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
        print(f"image is saved to sample/example/sample_t2i_{args.condition_type}.png")
        print(prompts)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")  
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--compile", action='store_true', default=False)
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
    parser.add_argument("--image-H", type=int, default=512)
    parser.add_argument("--image-W", type=int, default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--cfg-scale", type=float, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")

    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth', 'canny_base'], default="canny")
    parser.add_argument("--prompt", type=str, default='a high-quality image')
    parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
    parser.add_argument("--adapter-size", type=str, default='small')

    parser.add_argument("--control-strength", type=float, default=1.0)
    args = parser.parse_args()
    main(args)


================================================
FILE: autoregressive/sample/sample_t2i_MR.py
================================================
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from torchvision.utils import save_image

import os
import sys
current_directory = os.getcwd()
sys.path.append(current_directory)
import time
import argparse
from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
from autoregressive.models.gpt import GPT_models
from autoregressive.models.gpt_t2i import GPT_models
from autoregressive.models.generate import generate
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from dataset.t2i_control import build_t2i_control_code
from accelerate import Accelerator
from dataset.build import build_dataset
from pathlib import Path
from accelerate.utils import ProjectConfiguration, set_seed
import torch.nn.functional as F
from condition.canny import CannyDetector
from condition.hed import HEDdetector
import numpy as np
from PIL import Image
from condition.lineart import LineArt
import cv2
from transformers import DPTImageProcessor, DPTForDepthEstimation
from condition.midas.depth import MidasDetector


def resize_image_to_16_multiple(image_path, condition_type='seg'):
    """Resize an image so each side is a multiple of 16 (32 for depth).

    The depth model requires side lengths divisible by 32; all other
    condition types only need divisibility by 16.
    """
    img = Image.open(image_path)
    w, h = img.size

    # round each side UP to the nearest required multiple
    multiple = 32 if condition_type == 'depth' else 16
    target_w = ((w + multiple - 1) // multiple) * multiple
    target_h = ((h + multiple - 1) // multiple) * multiple

    return img.resize((target_w, target_h))

def main(args):
    """Sample a text-to-image result at a size derived from the condition image
    (multi-resolution), guided by a spatial control (seg/canny/hed/lineart/depth),
    and save a grid of [condition, generated samples] under sample/example/.

    Expects ``args`` as produced by the argparse block in this file; runs fully
    under torch.no_grad().
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # create and load model (VQ image tokenizer used only for decoding indices)
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint  # free the CPU copy of the tokenizer weights
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
        condition_type=args.condition_type,
    ).to(device=device, dtype=precision)

    # The GPT checkpoint may be a .safetensors file or a torch pickle whose
    # state dict sits under one of several keys depending on the trainer used.
    _, file_extension = os.path.splitext(args.gpt_ckpt)
    if file_extension.lower() == '.safetensors':
        from safetensors.torch import load_file
        model_weight = load_file(args.gpt_ckpt)
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
    else:
        checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
        if "model" in checkpoint:  # ddp
            model_weight = checkpoint["model"]
        elif "module" in checkpoint: # deepspeed
            model_weight = checkpoint["module"]
        elif "state_dict" in checkpoint:
            model_weight = checkpoint["state_dict"]
        else:
            raise Exception("please check model weight")
        gpt_model.load_state_dict(model_weight, strict=False)
        gpt_model.eval()
        del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo") 
    
    # T5 text encoder used to embed the prompt.
    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device, 
        local_cache=True, 
        cache_dir=args.t5_path, 
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )
    

    # Instantiate the condition extractor matching --condition-type
    # ('seg' needs no extractor: the input image is already the segmentation map).
    if args.condition_type == 'canny':
        get_control = CannyDetector()
    elif args.condition_type == 'hed':
        get_control = HEDdetector().to(device).eval()
    elif args.condition_type == 'lineart':
        get_control = LineArt()
        get_control.load_state_dict(torch.load('condition/ckpts/model.pth', map_location=torch.device('cpu')))
        get_control.to(device)
    elif args.condition_type == 'depth':
        processor = DPTImageProcessor.from_pretrained("condition/ckpts/dpt_large")
        model_large = DPTForDepthEstimation.from_pretrained("condition/ckpts/dpt_large").to(device)
        model = MidasDetector(device=device)
    with torch.no_grad():
        
        # Resize the control image so its sides suit the tokenizer (and depth model).
        condition_img = resize_image_to_16_multiple(args.condition_path, args.condition_type)
        W, H = condition_img.size
        print(H,W)
        # Each branch below normalizes the control into a (2, 3, H, W) tensor in
        # the 0-255 range; the batch of two matches the duplicated prompts below.
        if args.condition_type == 'seg':
            condition_img = torch.from_numpy(np.array(condition_img))
            condition_img = condition_img.permute(2,0,1).unsqueeze(0).repeat(2,1,1,1)
        elif args.condition_type == 'canny':
            condition_img = get_control(np.array(condition_img))
            condition_img = torch.from_numpy(condition_img[None,None,...]).repeat(2,3,1,1)
        elif args.condition_type == 'hed':
            condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device))
            condition_img = condition_img.unsqueeze(1).repeat(2,3,1,1)
        elif args.condition_type == 'lineart':
            condition_img = get_control(torch.from_numpy(np.array(condition_img)).permute(2,0,1).unsqueeze(0).to(device).float())
            # LineArt output is rescaled by 255 to match the 0-255 range of the other branches.
            condition_img = condition_img.repeat(2,3,1,1) * 255
        elif args.condition_type == 'depth':
            images = condition_img
            if H == W:
                # Square inputs go through the DPT-Large depth estimator...
                inputs = processor(images=images, return_tensors="pt", size=(H,W)).to(device)
                outputs = model_large(**inputs)
                condition_img = outputs.predicted_depth
                condition_img = (condition_img * 255 / condition_img.max())
            else:
                # ...non-square inputs use the MiDaS detector instead.
                condition_img = torch.from_numpy(model(torch.from_numpy(np.array(condition_img)).to(device))).unsqueeze(0)
            condition_img = condition_img.unsqueeze(0).repeat(2,3,1,1)
        condition_img = condition_img.to(device)
        condition_img = 2*(condition_img/255 - 0.5)  # normalize 0-255 -> [-1, 1]
        prompts = [args.prompt if args.prompt is not None else "a high-quality image"]
        prompts = prompts * 2  # two identical prompts -> two samples per run
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)

        if not args.no_left_padding:
            print(f"processing left-padding...")    
            # a naive way to implement left-padding
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                print(f'  prompt {idx} token len: {valid_num}')
                # Rotate so the valid tokens sit at the end (padding moves to the left).
                new_caption_emb = torch.cat([caption_emb[valid_num:],caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)
        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks
        c_indices = new_caption_embs * new_emb_masks[:,:, None]  # zero out padded positions
        c_emb_masks = new_emb_masks
        # Latent shape handed to the VQ decoder: (B, codebook_dim, H/ds, W/ds).
        qzshape = [len(c_indices), args.codebook_embed_dim, H//args.downsample_size, W//args.downsample_size]
        t1 = time.time()
        index_sample = generate(
            gpt_model, c_indices, (H//args.downsample_size)*(W//args.downsample_size),#latent_size ** 2, 
            c_emb_masks, condition=condition_img.to(precision),
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True, 
            )
        sampling_time = time.time() - t1
        print(f"Full sampling takes about {sampling_time:.2f} seconds.")    
        
        t2 = time.time()
        print(index_sample.shape)
        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
        decoder_time = time.time() - t2
        print(f"decoder takes about {decoder_time:.2f} seconds.")

        # Prepend the (normalized) condition image and save everything as one grid.
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        save_image(samples, f"sample/example/sample_t2i_MR_{args.condition_type}.png", nrow=4, normalize=True, value_range=(-1, 1))
        print(f"image is saved to sample/example/sample_t2i_MR_{args.condition_type}.png")
        print(prompts)


if __name__ == "__main__":
    # CLI for single-image, condition-guided text-to-image sampling.
    parser = argparse.ArgumentParser()
    # --- T5 text encoder ---
    parser.add_argument("--t5-path", type=str, default='checkpoints/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    # --- GPT (autoregressive generator) ---
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")  
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--compile", action='store_true', default=False)
    # --- VQ image tokenizer ---
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # --- image geometry ---
    parser.add_argument("--image-size", type=int, choices=[256, 320, 384, 400, 448, 512, 576, 640, 704, 768], default=768)
    parser.add_argument("--image-H", type=int, default=512)
    parser.add_argument("--image-W", type=int, default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    # --- sampling hyper-parameters ---
    parser.add_argument("--cfg-scale", type=float, default=4)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")

    # --- conditioning inputs ---
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--condition-type", type=str, choices=['seg', 'canny', 'hed', 'lineart', 'depth'], default="canny")
    parser.add_argument("--prompt", type=str, default='a high-quality image')
    parser.add_argument("--condition-path", type=str, default='condition/example/t2i/multigen/landscape.png')
    args = parser.parse_args()
    main(args)


================================================
FILE: autoregressive/sample/sample_t2i_ddp.py
================================================
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
import torch.nn.functional as F
import torch.distributed as dist

import os
import math
import json
import argparse
import pandas as pd
from tqdm import tqdm
from PIL import Image

from tokenizer.tokenizer_image.vq_model import VQ_models
from language.t5 import T5Embedder
from autoregressive.models.gpt import GPT_models
from autoregressive.models.generate import generate
os.environ["TOKENIZERS_PARALLELISM"] = "false"



def main(args):
    """Distributed (DDP) text-to-image sampling over a CSV of prompts.

    Every rank generates an interleaved slice of the prompt list, decodes the
    token indices with the VQ model, and writes individual PNGs; rank 0 then
    writes a result.jsonl (text -> image_path) and a captions.txt for
    downstream FID/CLIP evaluation.
    """
    # Setup PyTorch:
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP: one process per GPU; each rank gets a distinct derived seed.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # create and load model (VQ image tokenizer used only for decoding indices)
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = args.image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        block_size=latent_size ** 2,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)

    checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
 
    # The state dict key depends on which trainer produced the checkpoint.
    if "model" in checkpoint:  # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint: # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo") 
    
    # T5 text encoder used to embed the prompts.
    assert os.path.exists(args.t5_path)
    t5_model = T5Embedder(
        device=device, 
        local_cache=True, 
        cache_dir=args.t5_path, 
        dir_or_name=args.t5_model_type,
        torch_dtype=precision,
        model_max_length=args.t5_feature_max_len,
    )
    print(f"t5 model is loaded")

    # Create folder to save samples (name encodes model + sampling hyper-params):
    model_string_name = args.gpt_model.replace("/", "-")
    ckpt_string_name = os.path.basename(args.gpt_ckpt).replace(".pth", "").replace(".pt", "")
    prompt_name = args.prompt_csv.split('/')[-1].split('.')[0].lower()
    folder_name = f"{model_string_name}-{ckpt_string_name}-{prompt_name}-size-{args.image_size}-size-{args.image_size}-{args.vq_model}-" \
                  f"topk-{args.top_k}-topp-{args.top_p}-temperature-{args.temperature}-" \
                  f"cfg-{args.cfg_scale}-seed-{args.global_seed}"
    sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if rank == 0:
        os.makedirs(f"{sample_folder_dir}/images", exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}/images")
    dist.barrier()  # wait for rank 0 to create the output directory

    # Prompts are read from a tab-separated file with a 'Prompt' column.
    df = pd.read_csv(args.prompt_csv, delimiter='\t')
    prompt_list = df['Prompt'].tolist()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    num_fid_samples = min(args.num_fid_samples, len(prompt_list))
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Select text prompt. Indices interleave ranks so every rank works on a
        # disjoint slice; out-of-range padding indices get a dummy prompt.
        prompt_batch = []
        for i in range(n):
            index = i * dist.get_world_size() + rank + total
            prompt_batch.append(prompt_list[index] if index < len(prompt_list) else "a cute dog")
              
        # Sample inputs:
        caption_embs, emb_masks = t5_model.get_text_embeddings(prompt_batch)
        
        if not args.no_left_padding:
            # A naive left-padding: rotate each embedding so valid tokens sit at
            # the end, and flip the mask to match.
            new_emb_masks = torch.flip(emb_masks, dims=[-1])
            new_caption_embs = []
            for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
                valid_num = int(emb_mask.sum().item())
                # prompt_cur = prompt_batch[idx]
                # print(f'  prompt {idx} token len: {valid_num} : {prompt_cur}')
                new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
                new_caption_embs.append(new_caption_emb)
            new_caption_embs = torch.stack(new_caption_embs)

        else:
            new_caption_embs, new_emb_masks = caption_embs, emb_masks

        c_indices = new_caption_embs * new_emb_masks[:,:, None]  # zero out padded positions
        c_emb_masks = new_emb_masks

        # Latent shape handed to the VQ decoder: (B, codebook_dim, H/ds, W/ds).
        qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
        index_sample = generate(
            gpt_model, c_indices, latent_size ** 2, 
            c_emb_masks,
            cfg_scale=args.cfg_scale,
            temperature=args.temperature, top_k=args.top_k,
            top_p=args.top_p, sample_logits=True, 
            )
        
        samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
        # Map [-1, 1] floats to uint8 HWC images for saving.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        
        # Save samples to disk as individual .png files
        for i, sample in enumerate(samples):
            index = i * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/images/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        # Save infer result in a jsonl file
        json_items = []
        for idx, prompt in enumerate(prompt_list):
            image_path = os.path.join(sample_folder_dir, "images", f"{idx:06d}.png")
            json_items.append({"text": prompt, "image_path": image_path})
        res_jsonl_path = os.path.join(sample_folder_dir, "result.jsonl")
        print(f"Save jsonl to {res_jsonl_path}...")
        with open(res_jsonl_path, "w") as f:
            for item in json_items:
                f.write(json.dumps(item) + "\n")

        # Save captions to txt
        caption_path = os.path.join(sample_folder_dir, "captions.txt")
        print(f"Save captions to {caption_path}...")
        with open(caption_path, "w") as f:
            for item in prompt_list:
                f.write(f"{item}\n")
        print("Done.")
    
    dist.barrier()
    dist.destroy_process_group()



if __name__ == "__main__":
    # CLI for DDP batch sampling over a prompt CSV (FID/eval pipelines).
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt-csv", type=str, default='evaluations/t2i/PartiPrompts.tsv')
    # --- T5 text encoder ---
    parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt')
    parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
    parser.add_argument("--t5-feature-max-len", type=int, default=120)
    parser.add_argument("--t5-feature-dim", type=int, default=2048)
    parser.add_argument("--no-left-padding", action='store_true', default=False)
    # --- GPT (autoregressive generator) ---
    parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
    parser.add_argument("--gpt-ckpt", type=str, default=None)
    parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")  
    parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
    parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--compile", action='store_true', default=False)
    # --- VQ image tokenizer ---
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512)
    parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
    parser.add_argument("--num-classes", type=int, default=1000)
    # --- sampling / output ---
    parser.add_argument("--cfg-scale", type=float, default=7.5)
    parser.add_argument("--sample-dir", type=str, default="samples_parti", help="samples_coco or samples_parti")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=30000)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
    parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
    args = parser.parse_args()
    main(args)


================================================
FILE: autoregressive/serve/README.md
================================================
## Serving with vLLM

### Install
```
pip install vllm==0.4.1
```

### Comparison (A100)

Method | params | baseline(s) | vllm(s) | speed-up ratio 
--- |:---:|:---:|:---:|:---:
[GPT-B](./fake_json/GPT-B.json)    | 111M | 7.80    | 2.39      |  326 %
[GPT-L](./fake_json/GPT-L.json)    | 343M | 13.72   | 3.48      |  380 %
[GPT-XL](./fake_json/GPT-XL.json)  | 775M | 19.76   | 4.84      |  408 %
[GPT-XXL](./fake_json/GPT-XXL.json)| 1.4B | 26.38   | 6.36      |  414 %
[GPT-3B](./fake_json/GPT-3B.json)  | 3.1B | 14.73   | 6.26      |  235 %

```
### GPT-B
# 7.80 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384

# 2.39 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384


### GPT-L
# 13.72 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384

# 3.48 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384


### GPT-XL
# 19.76 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384

# 4.84 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384


### GPT-XXL
# 26.38 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384

# 6.36 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384


### GPT-3B
# 14.73 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384

# 6.26 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384

```


================================================
FILE: autoregressive/serve/fake_json/GPT-3B.json
================================================
{
  "_name_or_path": "facebook/opt-125m",
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "hidden_size": 3584,
  "init_std": 0.02,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 32,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.21.0.dev0",
  "use_cache": true,
  "vocab_size": 16384,
  "word_embed_proj_dim": 768
}


================================================
FILE: autoregressive/serve/fake_json/GPT-B.json
================================================
{
  "_name_or_path": "facebook/opt-125m",
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "hidden_size": 768,
  "init_std": 0.02,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.21.0.dev0",
  "use_cache": true,
  "vocab_size": 16384,
  "word_embed_proj_dim": 768
}


================================================
FILE: autoregressive/serve/fake_json/GPT-L.json
================================================
{
  "_name_or_path": "facebook/opt-125m",
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "hidden_size": 1024,
  "init_std": 0.02,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.21.0.dev0",
  "use_cache": true,
  "vocab_size": 16384,
  "word_embed_proj_dim": 768
}


================================================
FILE: autoregressive/serve/fake_json/GPT-XL.json
================================================
{
  "_name_or_path": "facebook/opt-125m",
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "hidden_size": 1280,
  "init_std": 0.02,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 20,
  "num_hidden_layers": 36,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.21.0.dev0",
  "use_cache": true,
  "vocab_size": 16384,
  "word_embed_proj_dim": 768
}


================================================
FILE: autoregressive/serve/fake_json/GPT-XXL.json
================================================
{
  "_name_or_path": "facebook/opt-125m",
  "activation_dropout": 0.0,
  "activation_function": "relu",
  "architectures": [
    "OPTForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 2,
  "do_layer_norm_before": true,
  "dropout": 0.1,
  "eos_token_id": 2,
  "ffn_dim": 3072,
  "hidden_size": 1536,
  "init_std": 0.02,
  "layerdrop": 0.0,
  "max_position_embeddings": 2048,
  "model_type": "opt",
  "num_attention_heads": 24,
  "num_hidden_layers": 48,
  "pad_token_id": 1,
  "prefix": "</s>",
  "torch_dtype": "bfloat16",
  "transformers_version": "4.21.0.dev0",
  "use_cache": true,
  "vocab_size": 16384,
  "word_embed_proj_dim": 768
}


================================================
FILE: autoregressive/serve/gpt_model.py
================================================
from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from vllm.attention import AttentionMetadata
from vllm.attention import Attention as pagedAttention

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from autoregressive.serve.sampler import Sampler

def find_multiple(n: int, k: int):
    """Round *n* up to the nearest multiple of *k* (returns *n* unchanged when already aligned)."""
    remainder = n % k
    return n if remainder == 0 else n + (k - remainder)

@dataclass
class ModelArgs:
    """Configuration for the vLLM-served GPT model.

    Field defaults mirror the training-side model; only a subset is used by
    the serving path in this file.
    """
    dim: int = 4096                          # transformer hidden size
    n_layer: int = 32                        # number of transformer blocks
    n_head: int = 32                         # attention heads
    n_kv_head: Optional[int] = None          # KV heads (GQA); None -> same as n_head
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional scale applied to the FFN hidden size
    rope_base: float = 10000                 # rotary position embedding base frequency
    norm_eps: float = 1e-5                   # epsilon for RMSNorm
    initializer_range: float = 0.02          # weight-init std (training-side; unused at serve time)
    
    num_classes: int = 1000                  # label vocabulary for class-conditional (c2i) models
    class_dropout_prob: float = 0.1          # label dropout prob (reserves a CFG "null" embedding row)
    model_type: str = 'c2i'                  # 'c2i' (class->image) or 't2i' (text->image)
    cfg_scale: float = 4.0                   # classifier-free guidance scale

    vocab_size: int = 16384                  # image-token codebook size
    cls_token_num: int = 1                   # number of conditioning tokens prepended to the sequence
    block_size: int = 256                    # image-token sequence length (must be a perfect square)
    max_batch_size: int = 32
    max_seq_len: int = 2048


#################################################################################
#                      Embedding Layers for Class Labels                        #
#################################################################################
class LabelEmbedder(nn.Module):
    """Looks up dense embeddings for integer class labels.

    When ``dropout_prob > 0`` one extra row is reserved in the table so a
    "null" label can stand in for dropped conditioning (classifier-free
    guidance). This serving-side version performs plain lookups only; the
    training-time label-dropout logic is not needed here.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra embedding row for the CFG null label if dropout is enabled.
        extra_null_row = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + extra_null_row, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels):
        """Return the embedding rows for *labels* (same leading shape, trailing dim = hidden_size)."""
        return self.embedding_table(labels)


#################################################################################
#                                  GPT Model                                    #
#################################################################################
# class RMSNorm(torch.nn.Module):
#     def __init__(self, dim: int, eps: float = 1e-5):
#         super().__init__()
#         self.eps = eps
#         self.weight = nn.Parameter(torch.ones(dim))

#     def _norm(self, x):
#         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

#     def forward(self, x):
#         output = self._norm(x.float()).type_as(x)
#         return output * self.weight


class FeedForward(nn.Module):
    """SwiGLU-style MLP with the gate and up projections fused into one matmul.

    The hidden width follows the LLaMA recipe: two thirds of ``4 * dim``,
    optionally scaled by ``ffn_dim_multiplier``, then rounded up to a
    multiple of ``config.multiple_of``.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = int(2 * (4 * config.dim) / 3)
        # custom dim factor multiplier
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        # w_merged packs w1 (gate) and w3 (up) so a single GEMM feeds SiluAndMul,
        # which splits its input in half and computes silu(a) * b.
        self.w_merged = nn.Linear(config.dim, hidden_dim * 2, bias=False)
        self.act_fn = SiluAndMul()
        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)

    def forward(self, x):
        """Project to the fused hidden space, gate with SiLU, project back to dim."""
        gated = self.act_fn(self.w_merged(x))
        return self.w2(gated)


class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection, 2D rotary position
    embeddings, and vLLM PagedAttention over the KV cache."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        # Fall back to full MHA when no separate KV head count is configured.
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # Q, K and V projections for all heads, computed in one matmul.
        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)

        # PagedAttention backend. Head size 100 is not supported by vLLM's
        # PagedAttention, so that case is overwritten by AttentionMonkeyPatch.
        if self.head_dim == 100:
            self.attn = None
        else:
            self.attn = pagedAttention(self.n_head, self.head_dim, self.head_dim ** -0.5, num_kv_heads=self.n_kv_head)

        # Precomputed 2D rotary table over the square image-token grid.
        grid_size = int(config.block_size ** 0.5)
        assert grid_size * grid_size == config.block_size
        freqs_cis = precompute_freqs_cis_2d(grid_size, self.head_dim, config.rope_base, config.cls_token_num)
        self.register_buffer('freqs_cis', freqs_cis)

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        kv_size = self.n_kv_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        # Insert a singleton sequence dim so the batched rotary helper sees
        # (tokens, 1, heads, head_dim).
        q = q.view(*q.shape[:-1], 1, self.n_head, self.head_dim)
        k = k.view(*k.shape[:-1], 1, self.n_kv_head, self.head_dim)
        rot = self.freqs_cis[positions].unsqueeze(1)
        q = apply_rotary_emb_bs(q, rot).flatten(1)
        k = apply_rotary_emb_bs(k, rot).flatten(1)

        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        return self.wo(attn_out)


class AttentionMonkeyPatch(Attention):
    """
    Note:
    In vllm, PagedAttention supports head sizes [64, 80, 96, 112, 128, 256].
    However, LlamaGen-3B model has head size 100 (for some historical reasons).
    Here we hack Attention to enable vllm support head size 100 by
    zero-padding Q/K/V up to head size 112 and slicing the output back down.
    """
    def __init__(self, config: ModelArgs):
        super().__init__(config)
        # Overwrite PagedAttention: advertise the padded head size (112) but
        # keep the scale of the real head size (100**-0.5) so attention
        # scores match the unpadded model.
        self.attn = pagedAttention(self.n_head, 112, 100**-0.5, num_kv_heads=self.n_kv_head)

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(*xq.shape[:-1], 1, self.n_head, self.head_dim)
        xk = xk.view(*xk.shape[:-1], 1, self.n_kv_head, self.head_dim)
        freqs_cis = self.freqs_cis[positions].unsqueeze(1)
        xq = apply_rotary_emb_bs(xq, freqs_cis)
        xk = apply_rotary_emb_bs(xk, freqs_cis)
        xq = xq.flatten(1)
        xk = xk.flatten(1)
        ############ padding to 112 to make vllm happy ############
        # Bugfix: K/V have n_kv_head heads, not n_head, so they need their
        # own pad tensor (the original padded all three with an n_head-sized
        # pad, which is shape-incompatible whenever n_kv_head != n_head).
        pad_width = 112 - self.head_dim
        q_pad = torch.zeros(xq.shape[0], self.n_head, pad_width, device=xq.device, dtype=xq.dtype)
        kv_pad = torch.zeros(xk.shape[0], self.n_kv_head, pad_width, device=xk.device, dtype=xk.dtype)
        xq = xq.reshape(xq.shape[0], self.n_head, self.head_dim)
        xk = xk.reshape(xk.shape[0], self.n_kv_head, self.head_dim)
        xv = xv.reshape(xv.shape[0], self.n_kv_head, self.head_dim)
        xq = torch.concat([xq, q_pad], dim=-1).flatten(1)
        xk = torch.concat([xk, kv_pad], dim=-1).flatten(1)
        xv = torch.concat([xv, kv_pad], dim=-1).flatten(1)

        output = self.attn(xq, xk, xv, kv_cache, attn_metadata)
        ############ de-padding to 100 ############
        output = output.reshape(output.shape[0], self.n_head, 112)
        output = output[..., :self.head_dim].flatten(1)

        output = self.wo(output)

        return output


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then a SwiGLU MLP, each wrapped
    in RMSNorm and a residual connection."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        # Head size 100 (LlamaGen-3B) needs the padded-attention workaround.
        attn_cls = AttentionMonkeyPatch if config.dim // config.n_head == 100 else Attention
        self.attention = attn_cls(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(self, x: torch.Tensor, positions: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata):
        attn_out = self.attention(self.attention_norm(x), positions, kv_cache, attn_metadata)
        residual = x + attn_out
        return residual + self.feed_forward(self.ffn_norm(residual))
        

class Transformer(nn.Module):
    """LlamaGen class-conditional (c2i) GPT adapted for vLLM serving.

    Wraps the decoder stack with vLLM's per-layer paged KV caches, a
    LogitsProcessor, and a CFG-aware Sampler. Only ``model_type == 'c2i'``
    is supported.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.cfg_scale = config.cfg_scale
        if self.model_type == 'c2i':
            # The label embedding doubles as the prompt ("class token")
            # embedding during prefill (see forward()).
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        else:
            raise Exception("vllm only supports c2i now, please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # vLLM head: logits come from self.output.weight (see compute_logits);
        # the sampler applies classifier-free guidance at cfg_scale.
        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.sampler = Sampler(config.cfg_scale)

    def forward(
        self,
        input_ids: torch.Tensor=None,
        positions: torch.Tensor=None,
        kv_caches: List[torch.Tensor]=None,
        attn_metadata: AttentionMetadata=None,
    ):
        """Run the decoder stack and return normalized hidden states.

        Branch-free phase selection: positions.max() == 0 only during
        prefill (the single class token), so the 0/1 scalar picks the label
        embedding then, and the token embedding on every decode step. Both
        tables are evaluated each call — presumably to avoid data-dependent
        control flow under vLLM execution; TODO confirm.
        NOTE(review): this assumes every sequence in the batch is in the same
        phase (all prefill or all decode) — verify against the scheduler.
        """
        # if positions.max() == 0: # prefill in inference
        #     token_embeddings = self.cls_embedding(input_ids)
        # else: # decode_n_tokens(kv cache) in inference
        #     token_embeddings = self.tok_embeddings(input_ids)
        # Clamp keeps ids within the label table's range so the (always
        # evaluated) prefill lookup never indexes out of bounds.
        cond_ids = torch.clamp(input_ids, max=self.num_classes)
        token_embeddings = self.cls_embedding(cond_ids) * (positions.max() == 0) + \
            self.tok_embeddings(input_ids) * (positions.max() != 0)

        hh = token_embeddings
        # transformer blocks, each with its own paged KV cache
        for layer_id, layer in enumerate(self.layers):
            hh = layer(hh, positions, kv_caches[layer_id], attn_metadata)
        
        # output layers
        hh = self.norm(hh)
        return hh

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # vLLM's LogitsProcessor takes the projection weight directly instead
        # of calling self.output as a module.
        logits = self.logits_processor(self.output.weight, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with the CFG-aware sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
        

    def custom_load_state_dict(self, model_weights):
        """Load a training-time checkpoint into this serving model.

        Fuses each layer's separate w1/w3 FFN weights into the single
        w_merged linear used here, and drops the checkpoint's freqs_cis
        (it is re-registered per Attention module at construction time).
        strict=False tolerates the remaining key differences between the
        training and serving module trees — NOTE(review): confirm nothing
        essential is silently skipped.
        """
        model_weights = model_weights.copy()
        for layer_id in range(len(self.layers)):
            branch1 = f'layers.{layer_id}.feed_forward.w1.weight'
            branch3 = f'layers.{layer_id}.feed_forward.w3.weight'
            branch_merged = f'layers.{layer_id}.feed_forward.w_merged.weight'
            model_weights[branch_merged] = torch.cat(
                [model_weights[branch1], model_weights[branch3]], dim=0
            )
            model_weights.pop(branch1)
            model_weights.pop(branch3)

        if 'freqs_cis' in model_weights:
            model_weights.pop('freqs_cis')
        
        self.load_state_dict(model_weights, strict=False)



#################################################################################
#                      Rotary Positional Embedding Functions                    #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """Precompute the 1D rotary table as (cos, sin) pairs.

    Returns a tensor of shape (cls_token_num + seq_len, n_elem // 2, 2).
    The first cls_token_num rows are zeros so condition tokens receive no
    rotation.
    """
    half = n_elem // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem))
    angles = torch.outer(torch.arange(seq_len, device=inv_freq.device), inv_freq)  # (seq_len, half)
    # (cos, sin) of the complex rotation exp(i * angle).
    table = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
    return torch.cat([torch.zeros(cls_token_num, half, 2), table])


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    """Precompute the 2D rotary table for a grid_size x grid_size token grid.

    Half of the head dimension encodes the row coordinate and half the
    column coordinate. Returns (cls_token_num + grid_size**2, n_elem // 2, 2)
    with zero rows for the condition tokens (no rotation).
    """
    half_dim = n_elem // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: half_dim // 2].float() / half_dim))
    coords = torch.arange(grid_size, device=inv_freq.device)
    angles = torch.outer(coords, inv_freq)  # (grid_size, n_elem // 4)
    # Row angles broadcast across columns, column angles across rows, then
    # concatenated per grid cell -> (grid_size, grid_size, n_elem // 2).
    angles_2d = torch.concat([
        angles[:, None, :].expand(-1, grid_size, -1),
        angles[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)
    # (cos, sin) pairs, flattened over the grid in row-major order.
    table = torch.stack([torch.cos(angles_2d), torch.sin(angles_2d)], dim=-1).flatten(0, 1)
    return torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), table])


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate adjacent channel pairs of x by a shared (cos, sin) table.

    x: (bs, seq_len, n_head, head_dim); freqs_cis: (seq_len, head_dim // 2, 2).
    The table is broadcast over batch and heads. Returns x's dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq, n_head, head_dim//2, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)
    cos, sin = rot[..., 0], rot[..., 1]
    re, im = pairs[..., 0], pairs[..., 1]
    # Complex multiply (re + i*im) * (cos + i*sin), kept as a real pair.
    rotated = torch.stack([re * cos - im * sin, im * cos + re * sin], dim=-1)
    return rotated.flatten(3).type_as(x)


def apply_rotary_emb_bs(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate adjacent channel pairs of x with a per-batch (cos, sin) table.

    Same as apply_rotary_emb, but freqs_cis carries a leading batch dim and
    is reshaped to (bs, seq_len, 1, head_dim // 2, 2) instead of being
    broadcast from a single shared table.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq, n_head, head_dim//2, 2)
    rot = freqs_cis.view(pairs.size(0), pairs.size(1), 1, pairs.size(3), 2)
    cos, sin = rot[..., 0], rot[..., 1]
    re, im = pairs[..., 0], pairs[..., 1]
    # Complex multiply (re + i*im) * (cos + i*sin), kept as a real pair.
    rotated = torch.stack([re * cos - im * sin, im * cos + re * sin], dim=-1)
    return rotated.flatten(3).type_as(x)


#################################################################################
#                                GPT Configs                                    #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    """Text-conditional 7B config (~6.6B parameters)."""
    model_args = ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)
    return Transformer(model_args)

def GPT_3B(**kwargs):
    """Text-conditional 3B config (~3.1B parameters)."""
    model_args = ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)
    return Transformer(model_args)

def GPT_1B(**kwargs):
    """Text-conditional 1B config (~1.2B parameters)."""
    model_args = ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)
    return Transformer(model_args)

### class-conditional
def GPT_XXXL(**kwargs):
    """Class-conditional XXXL config (~3.9B parameters)."""
    model_args = ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)
    return Transformer(model_args)

def GPT_XXL(**kwargs):
    """Class-conditional XXL config (~1.4B parameters)."""
    model_args = ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)
    return Transformer(model_args)

def GPT_XL(**kwargs):
    """Class-conditional XL config (~775M parameters)."""
    model_args = ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)
    return Transformer(model_args)

def GPT_L(**kwargs):
    """Class-conditional L config (~343M parameters)."""
    model_args = ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)
    return Transformer(model_args)

def GPT_B(**kwargs):
    """Class-conditional B config (~111M parameters)."""
    model_args = ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)
    return Transformer(model_args)
        

# Registry mapping model-size names (as used on the CLI) to constructors.
GPT_models = {
    'GPT-B': GPT_B,
    'GPT-L': GPT_L,
    'GPT-XL': GPT_XL,
    'GPT-XXL': GPT_XXL,
    'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B,
    'GPT-3B': GPT_3B,
    'GPT-7B': GPT_7B,
}

================================================
FILE: autoregressive/serve/gpu_executor.py
================================================
from typing import Dict, List, Set, Tuple, Optional, Set
import argparse

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

logger = init_logger(__name__)


class GPUExecutor(ExecutorBase):
    """Single-GPU executor that creates and drives one Worker.

    Adapted from vLLM's GPUExecutor; the functional difference is that the
    CLI ``args`` namespace is stored and forwarded to ``Worker.load_model``
    so the project's custom checkpoint loading can run
    (see ``_init_non_spec_worker``).
    """
    def __init__(
        self,
        args: argparse.ArgumentParser,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        # All configs are kept so they can be handed to the Worker in
        # _init_non_spec_worker / _init_spec_worker.
        self.args = args
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config
        self.speculative_config = speculative_config

        self._init_executor()

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        If speculative decoding is enabled, we instead create the speculative
        worker.
        """
        if self.speculative_config is None:
            self._init_non_spec_worker()
        else:
            self._init_spec_worker()

    def _init_non_spec_worker(self):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        # from vllm.worker.worker import Worker
        # Project-local Worker (not vLLM's) so load_model can take CLI args.
        from autoregressive.serve.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        # Custom: forward CLI args so the worker can load the GPT checkpoint.
        self.driver_worker.load_model(self.args)

    def _init_spec_worker(self):
        """Initialize a SpecDecodeWorker, using a draft model for proposals.

        NOTE(review): this path uses vLLM's stock Worker (no CLI args), so it
        presumably cannot load the custom checkpoint — confirm whether
        speculative decoding is actually supported here.
        """
        assert self.speculative_config is not None

        from vllm.spec_decode.multi_step_worker import MultiStepWorker
        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
        from vllm.worker.worker import Worker

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        target_worker = Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )

        draft_worker = MultiStepWorker(
            model_config=self.speculative_config.draft_model_config,
            parallel_config=self.speculative_config.draft_parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )

        spec_decode_worker = SpecDecodeWorker.from_workers(
            proposer_worker=draft_worker, scorer_worker=target_worker)

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = spec_decode_worker

        # Load model handled in spec decode worker.
        self.driver_worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        """Delegate one scheduler step to the driver worker."""
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots,
        )
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register a LoRA adapter on the driver worker."""
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA adapter from the driver worker."""
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the ids of LoRA adapters loaded on the driver worker."""
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        # GPUExecutor will always be healthy as long as
        # it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """Async facade over GPUExecutor: the driver worker's execute_model is
    wrapped with make_async so it does not block the event loop."""

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        run_model = make_async(self.driver_worker.execute_model)
        return await run_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)

================================================
FILE: autoregressive/serve/llm.py
================================================
# Modified from:
#   vLLM:    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import List, Optional, Union
import argparse

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
# from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

from autoregressive.serve.llm_engine import LLMEngine


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
    """

    def __init__(
        self,
        args: argparse.ArgumentParser,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        # Project-local LLMEngine (not vLLM's); it forwards the extra CLI
        # args down to the custom executor/worker for model loading.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS, args=args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the underlying HuggingFace tokenizer."""
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the underlying HuggingFace tokenizer."""
        self.llm_engine.tokenizer.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters. 
                When it is a single value, it is applied to every prompt. 
                When it is a list, the list must have the same length as the 
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if self.llm_engine.model_config.skip_tokenizer_init \
            and prompts is not None:
            raise ValueError("prompts must be None if skip_tokenizer_init "
                             "is True")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if prompts is not None:
            num_requests = len(prompts)
        else:
            assert prompt_token_ids is not None
            num_requests = len(prompt_token_ids)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        elif isinstance(sampling_params,
                        list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and sampling_params "
                             "must be the same.")
        if multi_modal_data:
            # NOTE(review): unconditional float16 cast — presumably to match
            # the model dtype; confirm this is valid for float32 runs.
            multi_modal_data.data = multi_modal_data.data.to(torch.float16)

        # Add requests to the engine.
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
            self._add_request(
                prompt,
                sampling_params[i]
                if isinstance(sampling_params, list) else sampling_params,
                token_ids,
                lora_request=lora_request,
                # Get ith image while maintaining the batch dim.
                multi_modal_data=MultiModalData(
                    type=multi_modal_data.type,
                    data=multi_modal_data.data[i].unsqueeze(0))
                if multi_modal_data else None,
            )
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Queue one request on the engine under a fresh monotonic id."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    multi_modal_data=multi_modal_data)


    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        """Step the engine until all queued requests finish.

        Returns the finished RequestOutputs sorted by request id (requests
        may finish out of order).
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=f"Generation Speed: {0:.2f} toks/s",
            )
        # Bugfix: total_toks was never initialized, so the first finished
        # request with use_tqdm=True raised NameError on `total_toks += ...`.
        total_toks = 0
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        total_toks += (sum(
                            len(stp.token_ids) for stp in output.outputs))
                        spd = total_toks / pbar.format_dict["elapsed"]
                        pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs


================================================
FILE: autoregressive/serve/llm_engine.py
================================================
# Modified from:
#   vLLM:    https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
import time
from typing import Iterable, List, Optional, Type, Union
import argparse

from transformers import GenerationConfig, PreTrainedTokenizer

import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                         LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig,
                         VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
                           SequenceGroup)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import Counter

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig):
    """Return the model's generation config as a diff dict, or {} if the
    model repository ships no generation config file."""
    try:
        generation_config = GenerationConfig.from_pretrained(
            model_config.model,
            revision=model_config.revision,
        )
    except OSError:
        # No generation_config.json for this model — fall back to defaults.
        return {}
    return generation_config.to_diff_dict()


class LLMEngine:
    """An LLM engine that receives requests and generates texts.

    This is the main class for the vLLM engine. It receives requests
    from clients and generates texts from the LLM. It includes a tokenizer, a
    language model (possibly distributed across multiple GPUs), and GPU memory
    space allocated for intermediate states (aka KV cache). This class utilizes
    iteration-level scheduling and efficient memory management to maximize the
    serving throughput.

    The `LLM` class wraps this class for offline batched inference and the
    `AsyncLLMEngine` class wraps this class for online serving.

    NOTE: The config arguments are derived from the `EngineArgs` class. For the
    comprehensive list of arguments, see `EngineArgs`.

    Args:
        model_config: The configuration related to the LLM model.
        cache_config: The configuration related to the KV cache memory
            management.
        parallel_config: The configuration related to distributed execution.
        scheduler_config: The configuration related to the request scheduler.
        device_config: The configuration related to the device.
        lora_config (Optional): The configuration related to serving multi-LoRA.
        vision_language_config (Optional): The configuration related to vision
            language models.
        speculative_config (Optional): The configuration related to speculative
            decoding.
        executor_class: The model executor class for managing distributed
            execution.
        log_stats: Whether to log statistics.
        usage_context: Specified entry point, used for usage info collection
    """

    def __init__(
        self,
        args: argparse.ArgumentParser,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
    ) -> None:
        """Wire together tokenizer, model executor, KV caches, scheduler and
        metric logging from pre-built config objects.

        See the class docstring for the meaning of each config argument.
        The extra ``args`` parameter is forwarded verbatim to the executor
        (a ControlAR addition on top of upstream vLLM).
        """
        # Log the effective configuration once at startup for debuggability.
        logger.info(
            f"Initializing an LLM engine (v{vllm.__version__}) with config: "
            f"model={model_config.model!r}, "
            f"speculative_config={speculative_config!r}, "
            f"tokenizer={model_config.tokenizer!r}, "
            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
            f"tokenizer_mode={model_config.tokenizer_mode}, "
            f"revision={model_config.revision}, "
            f"tokenizer_revision={model_config.tokenizer_revision}, "
            f"trust_remote_code={model_config.trust_remote_code}, "
            f"dtype={model_config.dtype}, "
            f"max_seq_len={model_config.max_model_len}, "
            f"download_dir={load_config.download_dir!r}, "
            f"load_format={load_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"disable_custom_all_reduce="
            f"{parallel_config.disable_custom_all_reduce}, "
            f"quantization={model_config.quantization}, "
            f"enforce_eager={model_config.enforce_eager}, "
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"quantization_param_path={model_config.quantization_param_path}, "
            f"device_config={device_config.device}, "
            f"decoding_config={decoding_config!r}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

        # Keep every config on the instance; downstream components read them.
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.load_config = load_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.log_stats = log_stats

        # Tokenizer (and its detokenizer wrapper) is optional: pure
        # token-in/token-out serving can skip it entirely.
        if not self.model_config.skip_tokenizer_init:
            self.tokenizer: BaseTokenizerGroup
            self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
        else:
            self.detokenizer = None
            self.tokenizer = None

        # Monotonic counter used to assign sequence IDs.
        self.seq_counter = Counter()
        self.generation_config_fields = _load_generation_config_dict(
            model_config)

        # The executor owns the worker(s)/model; `args` is the ControlAR
        # command-line namespace passed through to the custom executor.
        self.model_executor = executor_class(
            args=args,
            model_config=model_config,
            cache_config=cache_config,
            parallel_config=parallel_config,
            scheduler_config=scheduler_config,
            device_config=device_config,
            lora_config=lora_config,
            vision_language_config=vision_language_config,
            speculative_config=speculative_config,
            load_config=load_config,
        )

        # Profiles available memory and allocates GPU/CPU KV cache blocks;
        # must run before the scheduler is built (it reads the block counts).
        self._initialize_kv_caches()

        # If usage stat is enabled, collect relevant info.
        if is_usage_stats_enabled():
            from vllm.model_executor.model_loader import (
                get_architecture_class_name)
            usage_message.report_usage(
                get_architecture_class_name(model_config),
                usage_context,
                extra_kvs={
                    # Common configuration
                    "dtype":
                    str(model_config.dtype),
                    "tensor_parallel_size":
                    parallel_config.tensor_parallel_size,
                    "block_size":
                    cache_config.block_size,
                    "gpu_memory_utilization":
                    cache_config.gpu_memory_utilization,

                    # Quantization
                    "quantization":
                    model_config.quantization,
                    "kv_cache_dtype":
                    cache_config.cache_dtype,

                    # Feature flags
                    "enable_lora":
                    bool(lora_config),
                    "enable_prefix_caching":
                    cache_config.enable_prefix_caching,
                    "enforce_eager":
                    model_config.enforce_eager,
                    "disable_custom_all_reduce":
                    parallel_config.disable_custom_all_reduce,
                })

        if self.tokenizer:
            # Ping the tokenizer to ensure liveness if it runs in a
            # different process.
            self.tokenizer.ping()

        # Create the scheduler.
        # NOTE: the cache_config here have been updated with the numbers of
        # GPU and CPU blocks, which are profiled in the distributed executor.
        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)

        # Metric Logging.
        if self.log_stats:
            self.stat_logger = StatLogger(
                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                labels=dict(model_name=model_config.model))
            self.stat_logger.info("cache_config", self.cache_config)

        # Create sequence output processor, e.g. for beam search or
        # speculative decoding.
        self.output_processor = (
            SequenceGroupOutputProcessor.create_output_processor(
                self.scheduler_config,
                self.detokenizer,
                self.scheduler,
                self.seq_counter,
                self.get_tokenizer_for_seq,
                stop_checker=StopChecker(
                    self.scheduler_config.max_model_len,
                    self.get_tokenizer_for_seq,
                ),
            ))

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

        The workers will determine the number of blocks in both the GPU cache
        and the swap CPU cache.
        """
        gpu_blocks, cpu_blocks = \
            self.model_executor.determine_num_available_blocks()

        # A user-supplied override wins over the profiled GPU block count.
        override = self.cache_config.num_gpu_blocks_override
        if override is not None:
            logger.info(f"Overriding num_gpu_blocks={gpu_blocks!r} with "
                        f"num_gpu_blocks_override={override!r}")
            gpu_blocks = override

        self.cache_config.num_gpu_blocks = gpu_blocks
        self.cache_config.num_cpu_blocks = cpu_blocks
        self.model_executor.initialize_cache(gpu_blocks, cpu_blocks)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        args: Optional[argparse.ArgumentParser] = None,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Bundled engine configuration (see ``EngineArgs``).
            usage_context: Entry point tag for usage-stats collection.
            args: Extra ControlAR command-line namespace forwarded to the
                executor; may be None.
        """
        # Create the engine configs.
        engine_config = engine_args.create_engine_config()

        # Initialize the cluster and specify the executor class.
        # Executors are imported lazily so only the selected backend's
        # dependencies are required at runtime.
        if engine_config.device_config.device_type == "neuron":
            from vllm.executor.neuron_executor import NeuronExecutor
            executor_class = NeuronExecutor
        elif engine_config.device_config.device_type == "cpu":
            from vllm.executor.cpu_executor import CPUExecutor
            executor_class = CPUExecutor
        elif engine_config.parallel_config.worker_use_ray:
            initialize_ray_cluster(engine_config.parallel_config)
            from vllm.executor.ray_gpu_executor import RayGPUExecutor
            executor_class = RayGPUExecutor
        else:
            assert engine_config.parallel_config.world_size == 1, (
                "Ray is required if parallel_config.world_size > 1.")
            # ControlAR uses its own GPUExecutor instead of vLLM's default.
            # from vllm.executor.gpu_executor import GPUExecutor
            from autoregressive.serve.gpu_executor import GPUExecutor
            executor_class = GPUExecutor

        # Create the LLM engine.
        engine = cls(
            **engine_config.to_dict(),
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            args=args,
        )
        return engine

    def __reduce__(self):
        """Refuse pickling.

        This is to ensure that the LLMEngine is not referenced in the
        closure used to initialize Ray worker actors.
        """
        raise RuntimeError("LLMEngine should not be pickled!")

    def get_tokenizer(self) -> "PreTrainedTokenizer":
        """Return the base tokenizer (no LoRA adapter selected)."""
        tokenizer_group = self.tokenizer
        return tokenizer_group.get_lora_tokenizer(None)

    def get_tokenizer_for_seq(self,
                              sequence: Sequence) -> "PreTrainedTokenizer":
        """Return the tokenizer matching the sequence's LoRA request, if any."""
        lora_request = sequence.lora_request
        return self.tokenizer.get_lora_tokenizer(lora_request)

    def _init_tokenizer(self, **tokenizer_init_kwargs):
        """Create the tokenizer group from the engine configs.

        Keyword arguments override the defaults derived from the configs.
        """
        init_kwargs = {
            "tokenizer_id": self.model_config.tokenizer,
            "enable_lora": bool(self.lora_config),
            "max_num_seqs": self.scheduler_config.max_num_seqs,
            "max_input_length": None,
            "tokenizer_mode": self.model_config.tokenizer_mode,
            "trust_remote_code": self.model_config.trust_remote_code,
            "revision": self.model_config.tokenizer_revision,
            # Caller-supplied overrides take precedence.
            **tokenizer_init_kwargs,
        }
        self.tokenizer = get_tokenizer_group(
            self.parallel_config.tokenizer_pool_config, **init_kwargs)

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)
        if self.lora_config:
            self.lora_config.verify_with_model_config(
Download .txt
gitextract_x82jk08x/

├── .gitignore
├── LICENSE
├── README.md
├── autoregressive/
│   ├── models/
│   │   ├── README.md
│   │   ├── dinov2_adapter.py
│   │   ├── generate.py
│   │   ├── gpt.py
│   │   ├── gpt_t2i.py
│   │   └── vit_adapter.py
│   ├── sample/
│   │   ├── sample_c2i.py
│   │   ├── sample_c2i_ddp.py
│   │   ├── sample_t2i.py
│   │   ├── sample_t2i_MR.py
│   │   └── sample_t2i_ddp.py
│   ├── serve/
│   │   ├── README.md
│   │   ├── fake_json/
│   │   │   ├── GPT-3B.json
│   │   │   ├── GPT-B.json
│   │   │   ├── GPT-L.json
│   │   │   ├── GPT-XL.json
│   │   │   └── GPT-XXL.json
│   │   ├── gpt_model.py
│   │   ├── gpu_executor.py
│   │   ├── llm.py
│   │   ├── llm_engine.py
│   │   ├── model_runner.py
│   │   ├── sample_c2i.py
│   │   ├── sampler.py
│   │   └── worker.py
│   ├── test/
│   │   ├── metric.py
│   │   ├── test_c2i.py
│   │   ├── test_ssim.py
│   │   └── test_t2i.py
│   └── train/
│       ├── extract_codes_c2i.py
│       ├── extract_codes_t2i.py
│       ├── extract_file_ade.py
│       ├── extract_file_cocostuff.py
│       ├── extract_file_imagenet.py
│       ├── extract_file_multigen.py
│       ├── train_c2i.py
│       ├── train_c2i_canny.py
│       ├── train_c2i_depth.py
│       ├── train_c2i_fsdp.py
│       ├── train_t2i.py
│       ├── train_t2i_canny.py
│       ├── train_t2i_depth.py
│       ├── train_t2i_depth_multiscale.py
│       ├── train_t2i_hed.py
│       ├── train_t2i_hed_multiscale.py
│       ├── train_t2i_lineart.py
│       ├── train_t2i_lineart_multiscale.py
│       ├── train_t2i_seg.py
│       └── train_t2i_seg_multiscale.py
├── condition/
│   ├── README.md
│   ├── canny.py
│   ├── depth.py
│   ├── example/
│   │   └── c2i/
│   │       ├── canny/
│   │       │   ├── 15000.npy
│   │       │   ├── 2312.npy
│   │       │   ├── 48850.npy
│   │       │   └── 650.npy
│   │       └── depth/
│   │           ├── 101.npy
│   │           ├── 10601.npy
│   │           ├── 4351.npy
│   │           └── 48901.npy
│   ├── hed.py
│   ├── lineart.py
│   ├── midas/
│   │   ├── depth.py
│   │   └── midas/
│   │       ├── __init__.py
│   │       ├── base_model.py
│   │       ├── blocks.py
│   │       ├── dpt_depth.py
│   │       ├── midas_net.py
│   │       ├── midas_net_custom.py
│   │       ├── transforms.py
│   │       └── vit.py
│   └── utils.py
├── create_npz.py
├── dataset/
│   ├── augmentation.py
│   ├── build.py
│   ├── coco.py
│   ├── imagenet.py
│   ├── openimage.py
│   ├── pexels.py
│   ├── t2i.py
│   ├── t2i_control.py
│   └── utils.py
├── demo/
│   ├── app.py
│   ├── app_depth.py
│   ├── app_edge.py
│   └── model.py
├── evaluations/
│   ├── ade20k_mIoU.py
│   ├── c2i/
│   │   ├── README.md
│   │   └── evaluator.py
│   ├── canny_f1score.py
│   ├── clean_fid.py
│   ├── cocostuff_mIoU.py
│   ├── depth_rmse.py
│   ├── hed_ssim.py
│   ├── lineart_ssim.py
│   └── t2i/
│       ├── PartiPrompts.tsv
│       ├── README.md
│       ├── coco_captions.csv
│       └── evaluation.py
├── language/
│   ├── README.md
│   ├── extract_t5_feature.py
│   └── t5.py
├── requirements.txt
├── scripts/
│   ├── autoregressive/
│   │   ├── extract_codes_c2i.sh
│   │   ├── extract_file_ade.sh
│   │   ├── extract_file_cocostuff.sh
│   │   ├── extract_file_imagenet.sh
│   │   ├── extract_file_multigen.sh
│   │   ├── sample_c2i.sh
│   │   ├── sample_t2i_coco.sh
│   │   ├── sample_t2i_parti.sh
│   │   ├── test_c2i.sh
│   │   ├── test_t2i.sh
│   │   ├── train_c2i.sh
│   │   ├── train_c2i_canny.sh
│   │   ├── train_c2i_depth.sh
│   │   ├── train_c2i_fsdp.sh
│   │   ├── train_t2i_canny.sh
│   │   ├── train_t2i_depth.sh
│   │   ├── train_t2i_depth_multiscale.sh
│   │   ├── train_t2i_hed.sh
│   │   ├── train_t2i_hed_multiscale.sh
│   │   ├── train_t2i_lineart.sh
│   │   ├── train_t2i_lineart_multiscale.sh
│   │   ├── train_t2i_seg.sh
│   │   ├── train_t2i_seg_multiscale.sh
│   │   ├── train_t2i_stage1.sh
│   │   └── train_t2i_stage2.sh
│   ├── language/
│   │   ├── extract_flan_t5_feat_laion_coco_stage1.sh
│   │   ├── extract_flan_t5_feat_stage2.sh
│   │   └── extract_flan_t5_feat_trunc_stage2.sh
│   └── tokenizer/
│       ├── reconstruction_consistency_decoder.sh
│       ├── reconstruction_vae.sh
│       ├── reconstruction_vq.sh
│       ├── reconstruction_vqgan.sh
│       ├── train_vq.sh
│       ├── train_vq_finetune.sh
│       ├── train_vq_finetune_continue.sh
│       └── val.sh
├── tokenizer/
│   ├── consistencydecoder/
│   │   ├── README.md
│   │   ├── cd_demo.py
│   │   └── reconstruction_cd_ddp.py
│   ├── tokenizer_image/
│   │   ├── cache/
│   │   │   └── vgg.pth
│   │   ├── discriminator.py
│   │   ├── discriminator_patchgan.py
│   │   ├── discriminator_stylegan.py
│   │   ├── lpips.py
│   │   ├── reconstruction_vq_ddp.py
│   │   ├── vq_demo.py
│   │   ├── vq_loss.py
│   │   ├── vq_model.py
│   │   ├── vq_model_hf.py
│   │   └── vq_train.py
│   ├── vae/
│   │   ├── README.md
│   │   ├── reconstruction_vae_ddp.py
│   │   └── sd_vae_demo.py
│   ├── validation/
│   │   └── val_ddp.py
│   └── vqgan/
│       ├── README.md
│       ├── configs/
│       │   ├── vqgan_imagenet_f16_1024.yaml
│       │   ├── vqgan_imagenet_f16_16384.yaml
│       │   ├── vqgan_openimage_f8_16384.yaml
│       │   └── vqgan_openimage_f8_256.yaml
│       ├── layer.py
│       ├── model.py
│       ├── quantize.py
│       ├── reconstruction_vqgan_ddp.py
│       └── taming_vqgan_demo.py
├── tools/
│   ├── check_image_codes.py
│   ├── convert_pytorch_lightning_to_torch.py
│   ├── draw_figure.py
│   ├── imagenet_en_cn.py
│   ├── openimage_json.py
│   ├── push_gpt_to_hf.py
│   └── push_vae_to_hf.py
└── utils/
    ├── data.py
    ├── deepspeed.py
    ├── distributed.py
    ├── drop_path.py
    ├── ema.py
    ├── logger.py
    └── video.py
Download .txt
SYMBOL INDEX (813 symbols across 110 files)

FILE: autoregressive/models/dinov2_adapter.py
  class Dinov2_Adapter (line 8) | class Dinov2_Adapter(nn.Module):
    method __init__ (line 9) | def __init__(self, input_dim=1, output_dim=768, attention=False, pool=...
    method to_patch14 (line 16) | def to_patch14(self, input):
    method forward (line 26) | def forward(self, x):

FILE: autoregressive/models/generate.py
  function top_k_top_p_filtering (line 17) | def top_k_top_p_filtering(
  function sample (line 59) | def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float...
  function logits_to_probs (line 77) | def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, ...
  function prefill (line 85) | def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_...
  function decode_one_token (line 97) | def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cf...
  function decode_n_tokens (line 113) | def decode_n_tokens(
  function generate (line 135) | def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0,...

FILE: autoregressive/models/gpt.py
  function get_causal_mask (line 19) | def get_causal_mask(seq_length):
  function find_multiple (line 25) | def find_multiple(n: int, k: int):
  class ModelArgs (line 31) | class ModelArgs:
  class LabelEmbedder (line 66) | class LabelEmbedder(nn.Module):
    method __init__ (line 70) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method token_drop (line 77) | def token_drop(self, labels, force_drop_ids=None):
    method forward (line 88) | def forward(self, labels, train, force_drop_ids=None):
  class ConditionEmbedder (line 99) | class ConditionEmbedder(nn.Module):
    method __init__ (line 103) | def __init__(self, in_channels, hidden_size, uncond_prob, token_num=12...
    method token_drop (line 109) | def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
    method forward (line 122) | def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
  class CaptionEmbedder (line 132) | class CaptionEmbedder(nn.Module):
    method __init__ (line 136) | def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
    method token_drop (line 142) | def token_drop(self, caption, force_drop_ids=None):
    method forward (line 153) | def forward(self, caption, train, force_drop_ids=None):
  class MLP (line 161) | class MLP(nn.Module):
    method __init__ (line 162) | def __init__(self, in_features, hidden_features, out_features):
    method forward (line 173) | def forward(self, x):
  class RMSNorm (line 183) | class RMSNorm(torch.nn.Module):
    method __init__ (line 184) | def __init__(self, dim: int, eps: float = 1e-5):
    method _norm (line 189) | def _norm(self, x):
    method forward (line 192) | def forward(self, x):
  class FeedForward (line 197) | class FeedForward(nn.Module):
    method __init__ (line 198) | def __init__(self, config: ModelArgs):
    method forward (line 212) | def forward(self, x):
  class KVCache (line 216) | class KVCache(nn.Module):
    method __init__ (line 217) | def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, d...
    method update (line 223) | def update(self, input_pos, k_val, v_val):
  class Attention (line 234) | class Attention(nn.Module):
    method __init__ (line 235) | def __init__(self, config: ModelArgs):
    method forward (line 253) | def forward(
  class TransformerBlock (line 290) | class TransformerBlock(nn.Module):
    method __init__ (line 291) | def __init__(self, config: ModelArgs, drop_path: float):
    method forward (line 299) | def forward(
  class Transformer (line 306) | class Transformer(nn.Module):
    method __init__ (line 307) | def __init__(self, config: ModelArgs):
    method initialize_weights (line 368) | def initialize_weights(self):
    method _init_weights (line 373) | def _init_weights(self, module):
    method setup_caches (line 382) | def setup_caches(self, max_batch_size, max_seq_length, dtype):
    method forward (line 400) | def forward(
    method get_fsdp_wrap_module_list (line 468) | def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
  function precompute_freqs_cis (line 477) | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, c...
  function precompute_freqs_cis_2d (line 487) | def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 100...
  function apply_rotary_emb (line 503) | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
  function GPT_7B (line 521) | def GPT_7B(**kwargs):
  function GPT_3B (line 524) | def GPT_3B(**kwargs):
  function GPT_1B (line 527) | def GPT_1B(**kwargs):
  function GPT_XXXL (line 531) | def GPT_XXXL(**kwargs):
  function GPT_XXL (line 534) | def GPT_XXL(**kwargs):
  function GPT_XL (line 537) | def GPT_XL(**kwargs):
  function GPT_L (line 540) | def GPT_L(**kwargs):
  function GPT_B (line 543) | def GPT_B(**kwargs):

FILE: autoregressive/models/gpt_t2i.py
  function get_causal_mask (line 20) | def get_causal_mask(seq_length):
  function find_multiple (line 26) | def find_multiple(n: int, k: int):
  class ModelArgs (line 32) | class ModelArgs:
  class LabelEmbedder (line 67) | class LabelEmbedder(nn.Module):
    method __init__ (line 71) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method token_drop (line 78) | def token_drop(self, labels, force_drop_ids=None):
    method forward (line 89) | def forward(self, labels, train, force_drop_ids=None):
  class ConditionEmbedder (line 100) | class ConditionEmbedder(nn.Module):
    method __init__ (line 104) | def __init__(self, in_channels, hidden_size, uncond_prob, token_num=12...
    method token_drop (line 110) | def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
    method forward (line 123) | def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
  class CaptionEmbedder (line 133) | class CaptionEmbedder(nn.Module):
    method __init__ (line 137) | def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
    method token_drop (line 143) | def token_drop(self, caption, force_drop_ids=None):
    method forward (line 154) | def forward(self, caption, train, force_drop_ids=None):
  class MLP (line 165) | class MLP(nn.Module):
    method __init__ (line 166) | def __init__(self, in_features, hidden_features, out_features):
    method forward (line 177) | def forward(self, x):
  class RMSNorm (line 187) | class RMSNorm(torch.nn.Module):
    method __init__ (line 188) | def __init__(self, dim: int, eps: float = 1e-5):
    method _norm (line 193) | def _norm(self, x):
    method forward (line 196) | def forward(self, x):
  class FeedForward (line 201) | class FeedForward(nn.Module):
    method __init__ (line 202) | def __init__(self, config: ModelArgs):
    method forward (line 216) | def forward(self, x):
  class KVCache (line 220) | class KVCache(nn.Module):
    method __init__ (line 221) | def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, d...
    method update (line 227) | def update(self, input_pos, k_val, v_val):
  class Attention (line 238) | class Attention(nn.Module):
    method __init__ (line 239) | def __init__(self, config: ModelArgs):
    method forward (line 257) | def forward(
  class TransformerBlock (line 294) | class TransformerBlock(nn.Module):
    method __init__ (line 295) | def __init__(self, config: ModelArgs, drop_path: float):
    method forward (line 303) | def forward(
  class Transformer (line 310) | class Transformer(nn.Module):
    method __init__ (line 311) | def __init__(self, config: ModelArgs):
    method initialize_weights (line 372) | def initialize_weights(self):
    method _init_weights (line 381) | def _init_weights(self, module):
    method setup_caches (line 391) | def setup_caches(self, max_batch_size, max_seq_length, dtype):
    method forward (line 409) | def forward(
    method get_fsdp_wrap_module_list (line 487) | def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
  function precompute_freqs_cis (line 496) | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, c...
  function precompute_freqs_cis_2d (line 506) | def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 100...
  function apply_rotary_emb (line 522) | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
  function GPT_7B (line 540) | def GPT_7B(**kwargs):
  function GPT_3B (line 543) | def GPT_3B(**kwargs):
  function GPT_1B (line 546) | def GPT_1B(**kwargs):
  function GPT_XXXL (line 550) | def GPT_XXXL(**kwargs):
  function GPT_XXL (line 553) | def GPT_XXL(**kwargs):
  function GPT_XL (line 556) | def GPT_XL(**kwargs):
  function GPT_L (line 559) | def GPT_L(**kwargs):
  function GPT_B (line 562) | def GPT_B(**kwargs):

FILE: autoregressive/models/vit_adapter.py
  class ViT_Adapter (line 8) | class ViT_Adapter(nn.Module):
    method __init__ (line 9) | def __init__(self, input_dim=3, output_dim=768, attention=False, pool=...
    method forward (line 13) | def forward(self, x):

FILE: autoregressive/sample/sample_c2i.py
  function main (line 27) | def main(args):

FILE: autoregressive/sample/sample_c2i_ddp.py
  function create_npz_from_sample_folder (line 21) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function main (line 38) | def main(args):

FILE: autoregressive/sample/sample_t2i.py
  function main (line 34) | def main(args):

FILE: autoregressive/sample/sample_t2i_MR.py
  function resize_image_to_16_multiple (line 37) | def resize_image_to_16_multiple(image_path, condition_type='seg'):
  function main (line 51) | def main(args):

FILE: autoregressive/sample/sample_t2i_ddp.py
  function main (line 26) | def main(args):

FILE: autoregressive/serve/gpt_model.py
  function find_multiple (line 18) | def find_multiple(n: int, k: int):
  class ModelArgs (line 24) | class ModelArgs:
  class LabelEmbedder (line 50) | class LabelEmbedder(nn.Module):
    method __init__ (line 54) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method forward (line 73) | def forward(self, labels):
  class FeedForward (line 98) | class FeedForward(nn.Module):
    method __init__ (line 99) | def __init__(self, config: ModelArgs):
    method forward (line 119) | def forward(self, x):
  class Attention (line 127) | class Attention(nn.Module):
    method __init__ (line 128) | def __init__(self, config: ModelArgs):
    method forward (line 154) | def forward(
  class AttentionMonkeyPatch (line 178) | class AttentionMonkeyPatch(Attention):
    method __init__ (line 185) | def __init__(self, config: ModelArgs):
    method forward (line 191) | def forward(
  class TransformerBlock (line 227) | class TransformerBlock(nn.Module):
    method __init__ (line 228) | def __init__(self, config: ModelArgs):
    method forward (line 238) | def forward(self, x: torch.Tensor, positions: torch.Tensor, kv_cache: ...
  class Transformer (line 244) | class Transformer(nn.Module):
    method __init__ (line 245) | def __init__(self, config: ModelArgs):
    method forward (line 273) | def forward(
    method compute_logits (line 297) | def compute_logits(self, hidden_states: torch.Tensor,
    method sample (line 302) | def sample(
    method custom_load_state_dict (line 311) | def custom_load_state_dict(self, model_weights):
  function precompute_freqs_cis (line 334) | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, c...
  function precompute_freqs_cis_2d (line 344) | def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 100...
  function apply_rotary_emb (line 360) | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
  function apply_rotary_emb_bs (line 373) | def apply_rotary_emb_bs(x: torch.Tensor, freqs_cis: torch.Tensor):
  function GPT_7B (line 390) | def GPT_7B(**kwargs):
  function GPT_3B (line 393) | def GPT_3B(**kwargs):
  function GPT_1B (line 396) | def GPT_1B(**kwargs):
  function GPT_XXXL (line 400) | def GPT_XXXL(**kwargs):
  function GPT_XXL (line 403) | def GPT_XXL(**kwargs):
  function GPT_XL (line 406) | def GPT_XL(**kwargs):
  function GPT_L (line 409) | def GPT_L(**kwargs):
  function GPT_B (line 412) | def GPT_B(**kwargs):

FILE: autoregressive/serve/gpu_executor.py
  class GPUExecutor (line 17) | class GPUExecutor(ExecutorBase):
    method __init__ (line 18) | def __init__(
    method _init_executor (line 44) | def _init_executor(self) -> None:
    method _init_non_spec_worker (line 55) | def _init_non_spec_worker(self):
    method _init_spec_worker (line 83) | def _init_spec_worker(self):
    method determine_num_available_blocks (line 136) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method initialize_cache (line 142) | def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
    method execute_model (line 153) | def execute_model(
    method add_lora (line 170) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 174) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 178) | def list_loras(self) -> Set[int]:
    method check_health (line 181) | def check_health(self) -> None:
  class GPUExecutorAsync (line 187) | class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    method execute_model_async (line 189) | async def execute_model_async(

FILE: autoregressive/serve/llm.py
  class LLM (line 22) | class LLM:
    method __init__ (line 82) | def __init__(
    method get_tokenizer (line 128) | def get_tokenizer(
    method set_tokenizer (line 132) | def set_tokenizer(
    method generate (line 138) | def generate(
    method _add_request (line 221) | def _add_request(
    method _run_engine (line 238) | def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:

FILE: autoregressive/serve/llm_engine.py
  function _load_generation_config_dict (line 40) | def _load_generation_config_dict(model_config: ModelConfig):
  class LLMEngine (line 51) | class LLMEngine:
    method __init__ (line 85) | def __init__(
    method _initialize_kv_caches (line 234) | def _initialize_kv_caches(self) -> None:
    method from_engine_args (line 255) | def from_engine_args(
    method __reduce__ (line 293) | def __reduce__(self):
    method get_tokenizer (line 298) | def get_tokenizer(self) -> "PreTrainedTokenizer":
    method get_tokenizer_for_seq (line 301) | def get_tokenizer_for_seq(self,
    method _init_tokenizer (line 305) | def _init_tokenizer(self, **tokenizer_init_kwargs):
    method _verify_args (line 318) | def _verify_args(self) -> None:
    method encode_request (line 326) | def encode_request(
    method add_request (line 340) | def add_request(
    method abort_request (line 439) | def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
    method get_model_config (line 458) | def get_model_config(self) -> ModelConfig:
    method get_num_unfinished_requests (line 462) | def get_num_unfinished_requests(self) -> int:
    method has_unfinished_requests (line 466) | def has_unfinished_requests(self) -> bool:
    method _process_model_outputs (line 470) | def _process_model_outputs(
    method step (line 511) | def step(self) -> List[RequestOutput]:
    method do_log_stats (line 583) | def do_log_stats(self) -> None:
    method _get_stats (line 588) | def _get_stats(self,
    method add_lora (line 661) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 664) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 667) | def list_loras(self) -> List[int]:
    method check_health (line 670) | def check_health(self) -> None:

FILE: autoregressive/serve/model_runner.py
  class PreparePromptMetadata (line 43) | class PreparePromptMetadata(NamedTuple):
    method empty (line 56) | def empty(cls):
  class PrepareDecodeMetadata (line 71) | class PrepareDecodeMetadata(NamedTuple):
    method empty (line 81) | def empty(cls):
  class BatchType (line 94) | class BatchType(IntEnum):
  class ModelRunner (line 103) | class ModelRunner:
    method __init__ (line 105) | def __init__(
    method load_model (line 161) | def load_model(self, args) -> None:
    method set_block_size (line 237) | def set_block_size(self, block_size: int) -> None:
    method get_max_block_per_batch (line 244) | def get_max_block_per_batch(self) -> int:
    method _prepare_prompt (line 248) | def _prepare_prompt(
    method _prepare_decode (line 448) | def _prepare_decode(
    method _prepare_sample (line 574) | def _prepare_sample(
    method prepare_input_tensors (line 676) | def prepare_input_tensors(
    method execute_model (line 845) | def execute_model(
    method profile_run (line 889) | def profile_run(self) -> None:
    method remove_all_loras (line 955) | def remove_all_loras(self) -> bool:
    method set_active_loras (line 960) | def set_active_loras(self, lora_requests: Set[LoRARequest],
    method add_lora (line 966) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 971) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 976) | def list_loras(self) -> Set[int]:
    method capture_model (line 982) | def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
    method __del__ (line 1083) | def __del__(self) -> None:
    method vocab_size (line 1094) | def vocab_size(self) -> int:
  class CUDAGraphRunner (line 1098) | class CUDAGraphRunner:
    method __init__ (line 1100) | def __init__(self, model: nn.Module):
    method graph (line 1108) | def graph(self):
    method capture (line 1112) | def capture(
    method forward (line 1162) | def forward(
    method __call__ (line 1188) | def __call__(self, *args, **kwargs):
  function _maybe_pynccl (line 1193) | def _maybe_pynccl():
  function _get_graph_batch_size (line 1202) | def _get_graph_batch_size(batch_size: int) -> int:
  function _prepare_fake_inputs (line 1217) | def _prepare_fake_inputs(

FILE: autoregressive/serve/sample_c2i.py
  function main (line 13) | def main(args):

FILE: autoregressive/serve/sampler.py
  class Sampler (line 17) | class Sampler(nn.Module):
    method __init__ (line 38) | def __init__(self, cfg_scale=1.0):
    method forward (line 46) | def forward(
    method _should_modify_greedy_probs_inplace (line 128) | def _should_modify_greedy_probs_inplace(self) -> bool:
  function _get_bin_counts_and_mask (line 143) | def _get_bin_counts_and_mask(
  function _apply_min_tokens_penalty (line 160) | def _apply_min_tokens_penalty(
  function _apply_penalties (line 207) | def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.T...
  function _apply_top_k_top_p (line 230) | def _apply_top_k_top_p(
  function _apply_min_p (line 262) | def _apply_min_p(
  function _greedy_sample (line 279) | def _greedy_sample(
  function _random_sample (line 298) | def _random_sample(
  function _beam_search_sample (line 325) | def _beam_search_sample(
  function _multinomial (line 383) | def _multinomial(
  function _sample_with_torch (line 410) | def _sample_with_torch(
  function _sample_with_triton_kernel (line 520) | def _sample_with_triton_kernel(
  function _sample (line 600) | def _sample(
  function _get_ranks (line 618) | def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
  function _get_logprobs (line 637) | def _get_logprobs(
  function _modify_greedy_probs_inplace (line 775) | def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Te...
  function _build_sampler_output (line 826) | def _build_sampler_output(

FILE: autoregressive/serve/worker.py
  class Worker (line 27) | class Worker(WorkerBase):
    method __init__ (line 35) | def __init__(
    method init_device (line 89) | def init_device(self) -> None:
    method load_model (line 117) | def load_model(self, args):
    method determine_num_available_blocks (line 121) | def determine_num_available_blocks(self) -> Tuple[int, int]:
    method initialize_cache (line 166) | def initialize_cache(self, num_gpu_blocks: int,
    method _init_cache_engine (line 182) | def _init_cache_engine(self):
    method _warm_up_model (line 189) | def _warm_up_model(self) -> None:
    method cache_swap (line 196) | def cache_swap(
    method execute_model (line 212) | def execute_model(
    method add_lora (line 257) | def add_lora(self, lora_request: LoRARequest) -> bool:
    method remove_lora (line 260) | def remove_lora(self, lora_id: int) -> bool:
    method list_loras (line 263) | def list_loras(self) -> Set[int]:
    method max_model_len (line 267) | def max_model_len(self) -> int:
    method vocab_size (line 271) | def vocab_size(self) -> int:
    method get_cache_block_size_bytes (line 274) | def get_cache_block_size_bytes(self) -> int:
  function init_worker_distributed_environment (line 282) | def init_worker_distributed_environment(
  function _check_if_gpu_supports_dtype (line 322) | def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
  function raise_if_cache_size_invalid (line 336) | def raise_if_cache_size_invalid(num_gpu_blocks, block_size,

FILE: autoregressive/test/metric.py
  class SSIM (line 7) | class SSIM:
    method __init__ (line 8) | def __init__(self, data_range=1.0):
    method update (line 12) | def update(self, img1, img2):
    method calculate (line 17) | def calculate(self):
  class F1score (line 24) | class F1score:
    method __init__ (line 25) | def __init__(self, threshold=128):
    method update (line 30) | def update(self, img1, img2):
    method calculate (line 45) | def calculate(self):
  class RMSE (line 49) | class RMSE:
    method __init__ (line 50) | def __init__(self):
    method update (line 54) | def update(self, img1, img2):
    method calculate (line 64) | def calculate(self):

FILE: autoregressive/test/test_c2i.py
  function create_npz_from_sample_folder (line 37) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function main (line 53) | def main(args):

FILE: autoregressive/test/test_t2i.py
  function main (line 43) | def main(args):

FILE: autoregressive/train/extract_codes_c2i.py
  function main (line 25) | def main(args):

FILE: autoregressive/train/extract_codes_t2i.py
  class CustomDataset (line 25) | class CustomDataset(Dataset):
    method __init__ (line 26) | def __init__(self, lst_dir, start, end, transform):
    method __len__ (line 41) | def __len__(self):
    method __getitem__ (line 44) | def __getitem__(self, index):
  function main (line 56) | def main(args):

FILE: autoregressive/train/extract_file_ade.py
  function collate_fn (line 53) | def collate_fn(examples):
  function main (line 79) | def main(args):

FILE: autoregressive/train/extract_file_cocostuff.py
  function collate_fn (line 53) | def collate_fn(examples):
  function main (line 79) | def main(args):

FILE: autoregressive/train/extract_file_imagenet.py
  function main (line 29) | def main(args):

FILE: autoregressive/train/extract_file_multigen.py
  function collate_fn (line 59) | def collate_fn(examples):
  function main (line 79) | def main(args):

FILE: autoregressive/train/train_c2i.py
  function creat_optimizer (line 28) | def creat_optimizer(model, weight_decay, learning_rate, betas, logger):
  function main (line 57) | def main(args):

FILE: autoregressive/train/train_c2i_canny.py
  function creat_optimizer (line 31) | def creat_optimizer(model, weight_decay, learning_rate, betas, logger):
  function main (line 59) | def main(args):

FILE: autoregressive/train/train_c2i_depth.py
  function creat_optimizer (line 33) | def creat_optimizer(model, weight_decay, learning_rate, betas, logger):
  function main (line 60) | def main(args):

FILE: autoregressive/train/train_c2i_fsdp.py
  function setup_fsdp_sync (line 31) | def setup_fsdp_sync(model: nn.Module, args: argparse.Namespace, device) ...
  function creat_optimizer_by_name (line 67) | def creat_optimizer_by_name(model, weight_decay, learning_rate, betas, g...
  function main (line 102) | def main(args):

FILE: autoregressive/train/train_t2i.py
  function main (line 27) | def main(args):

FILE: autoregressive/train/train_t2i_canny.py
  function main (line 38) | def main(args):

FILE: autoregressive/train/train_t2i_depth.py
  function main (line 38) | def main(args):

FILE: autoregressive/train/train_t2i_depth_multiscale.py
  function random_sample_scale (line 44) | def random_sample_scale(image, condition=None):
  function main (line 59) | def main(args):

FILE: autoregressive/train/train_t2i_hed.py
  function main (line 38) | def main(args):

FILE: autoregressive/train/train_t2i_hed_multiscale.py
  function random_sample_scale (line 42) | def random_sample_scale(image, condition=None):
  function main (line 57) | def main(args):

FILE: autoregressive/train/train_t2i_lineart.py
  function main (line 39) | def main(args):

FILE: autoregressive/train/train_t2i_lineart_multiscale.py
  function random_sample_scale (line 44) | def random_sample_scale(image, condition=None):
  function main (line 58) | def main(args):

FILE: autoregressive/train/train_t2i_seg.py
  function main (line 39) | def main(args):

FILE: autoregressive/train/train_t2i_seg_multiscale.py
  function random_sample_scale (line 43) | def random_sample_scale(image, condition=None):
  function main (line 57) | def main(args):

FILE: condition/canny.py
  class CannyDetector (line 6) | class CannyDetector:
    method __call__ (line 7) | def __call__(self, img, low_threshold=100, high_threshold=200):

FILE: condition/depth.py
  class Depth (line 6) | class Depth:
    method __init__ (line 7) | def __init__(self, device):
    method __call__ (line 10) | def __call__(self, input_image):

FILE: condition/hed.py
  class DoubleConvBlock (line 17) | class DoubleConvBlock(torch.nn.Module):
    method __init__ (line 18) | def __init__(self, input_channel, output_channel, layer_number):
    method __call__ (line 26) | def __call__(self, x, down_sampling=False):
  class ControlNetHED_Apache2 (line 36) | class ControlNetHED_Apache2(torch.nn.Module):
    method __init__ (line 37) | def __init__(self):
    method __call__ (line 46) | def __call__(self, x):
  class HEDdetector (line 56) | class HEDdetector(torch.nn.Module):
    method __init__ (line 57) | def __init__(self):
    method __call__ (line 67) | def __call__(self, input_image):
  function nms (line 84) | def nms(x, t, s):

FILE: condition/lineart.py
  class ResidualBlock (line 9) | class ResidualBlock(nn.Module):
    method __init__ (line 10) | def __init__(self, in_features):
    method forward (line 24) | def forward(self, x):
  class LineArt (line 26) | class LineArt(nn.Module):
    method __init__ (line 27) | def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmo...
    method forward (line 74) | def forward(self, x, cond=None):

FILE: condition/midas/depth.py
  function disabled_train (line 31) | def disabled_train(self, mode=True):
  function load_midas_transform (line 37) | def load_midas_transform(model_type):
  function load_model (line 82) | def load_model(model_type):
  class MiDaSInference (line 150) | class MiDaSInference(nn.Module):
    method __init__ (line 163) | def __init__(self, model_type):
    method forward (line 170) | def forward(self, x):
  class MidasDetector (line 176) | class MidasDetector:
    method __init__ (line 177) | def __init__(self,device=torch.device('cuda:0'), model_type="dpt_hybri...
    method __call__ (line 181) | def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):

FILE: condition/midas/midas/base_model.py
  class BaseModel (line 4) | class BaseModel(torch.nn.Module):
    method load (line 5) | def load(self, path):

FILE: condition/midas/midas/blocks.py
  function _make_encoder (line 11) | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=F...
  function _make_scratch (line 49) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
  function _make_pretrained_efficientnet_lite3 (line 78) | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
  function _make_efficientnet_backbone (line 88) | def _make_efficientnet_backbone(effnet):
  function _make_resnet_backbone (line 101) | def _make_resnet_backbone(resnet):
  function _make_pretrained_resnext101_wsl (line 114) | def _make_pretrained_resnext101_wsl(use_pretrained):
  class Interpolate (line 120) | class Interpolate(nn.Module):
    method __init__ (line 124) | def __init__(self, scale_factor, mode, align_corners=False):
    method forward (line 138) | def forward(self, x):
  class ResidualConvUnit (line 155) | class ResidualConvUnit(nn.Module):
    method __init__ (line 159) | def __init__(self, features):
    method forward (line 177) | def forward(self, x):
  class FeatureFusionBlock (line 194) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 198) | def __init__(self, features):
    method forward (line 209) | def forward(self, *xs):
  class ResidualConvUnit_custom (line 231) | class ResidualConvUnit_custom(nn.Module):
    method __init__ (line 235) | def __init__(self, features, activation, bn):
    method forward (line 263) | def forward(self, x):
  class FeatureFusionBlock_custom (line 291) | class FeatureFusionBlock_custom(nn.Module):
    method __init__ (line 295) | def __init__(self, features, activation, deconv=False, bn=False, expan...
    method forward (line 320) | def forward(self, *xs):

FILE: condition/midas/midas/dpt_depth.py
  function _make_fusion_block (line 15) | def _make_fusion_block(features, use_bn):
  class DPT (line 26) | class DPT(BaseModel):
    method __init__ (line 27) | def __init__(
    method forward (line 67) | def forward(self, x):
  class DPTDepthModel (line 88) | class DPTDepthModel(DPT):
    method __init__ (line 89) | def __init__(self, path=None, non_negative=True, **kwargs):
    method forward (line 107) | def forward(self, x):

FILE: condition/midas/midas/midas_net.py
  class MidasNet (line 12) | class MidasNet(BaseModel):
    method __init__ (line 16) | def __init__(self, path=None, features=256, non_negative=True):
    method forward (line 49) | def forward(self, x):

FILE: condition/midas/midas/midas_net_custom.py
  class MidasNet_small (line 12) | class MidasNet_small(BaseModel):
    method __init__ (line 16) | def __init__(self, path=None, features=64, backbone="efficientnet_lite...
    method forward (line 73) | def forward(self, x):
  function fuse_model (line 109) | def fuse_model(m):

FILE: condition/midas/midas/transforms.py
  function apply_min_size (line 6) | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AR...
  class Resize (line 48) | class Resize(object):
    method __init__ (line 52) | def __init__(
    method constrain_to_multiple_of (line 94) | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
    method get_size (line 105) | def get_size(self, width, height):
    method __call__ (line 162) | def __call__(self, sample):
  class NormalizeImage (line 197) | class NormalizeImage(object):
    method __init__ (line 201) | def __init__(self, mean, std):
    method __call__ (line 205) | def __call__(self, sample):
  class PrepareForNet (line 211) | class PrepareForNet(object):
    method __init__ (line 215) | def __init__(self):
    method __call__ (line 218) | def __call__(self, sample):

FILE: condition/midas/midas/vit.py
  class Slice (line 9) | class Slice(nn.Module):
    method __init__ (line 10) | def __init__(self, start_index=1):
    method forward (line 14) | def forward(self, x):
  class AddReadout (line 18) | class AddReadout(nn.Module):
    method __init__ (line 19) | def __init__(self, start_index=1):
    method forward (line 23) | def forward(self, x):
  class ProjectReadout (line 31) | class ProjectReadout(nn.Module):
    method __init__ (line 32) | def __init__(self, in_features, start_index=1):
    method forward (line 38) | def forward(self, x):
  class Transpose (line 45) | class Transpose(nn.Module):
    method __init__ (line 46) | def __init__(self, dim0, dim1):
    method forward (line 51) | def forward(self, x):
  function forward_vit (line 56) | def forward_vit(pretrained, x):
  function _resize_pos_embed (line 100) | def _resize_pos_embed(self, posemb, gs_h, gs_w):
  function forward_flex (line 117) | def forward_flex(self, x):
  function get_activation (line 159) | def get_activation(name):
  function get_readout_oper (line 166) | def get_readout_oper(vit_features, features, use_readout, start_index=1):
  function _make_vit_b16_backbone (line 183) | def _make_vit_b16_backbone(
  function _make_pretrained_vitl16_384 (line 297) | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_vitb16_384 (line 310) | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_deitb16_384 (line 319) | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks...
  function _make_pretrained_deitb16_distil_384 (line 328) | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore"...
  function _make_vit_b_rn50_backbone (line 343) | def _make_vit_b_rn50_backbone(
  function _make_pretrained_vitb_rn50_384 (line 478) | def _make_pretrained_vitb_rn50_384(

FILE: condition/utils.py
  function HWC3 (line 9) | def HWC3(x):
  function resize_image (line 28) | def resize_image(input_image, resolution):

FILE: create_npz.py
  function create_npz_from_sample_folder (line 8) | def create_npz_from_sample_folder(sample_dir, num=50_000):

FILE: dataset/augmentation.py
  function center_crop_arr (line 8) | def center_crop_arr(pil_image, image_size):
  function random_crop_arr (line 29) | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_f...

FILE: dataset/build.py
  function build_dataset (line 8) | def build_dataset(args, **kwargs):

FILE: dataset/coco.py
  class SingleFolderDataset (line 7) | class SingleFolderDataset(Dataset):
    method __init__ (line 8) | def __init__(self, directory, transform=None):
    method __len__ (line 15) | def __len__(self):
    method __getitem__ (line 18) | def __getitem__(self, idx):
  function build_coco (line 26) | def build_coco(args, transform):

FILE: dataset/imagenet.py
  class CustomDataset (line 9) | class CustomDataset(Dataset):
    method __init__ (line 10) | def __init__(self, feature_dir, label_dir, condition_dir=None, get_con...
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 52) | def __getitem__(self, idx):
  function build_imagenet (line 104) | def build_imagenet(args, transform):
  function build_imagenet_code (line 107) | def build_imagenet_code(args):

FILE: dataset/openimage.py
  class DatasetJson (line 10) | class DatasetJson(Dataset):
    method __init__ (line 11) | def __init__(self, data_path, transform=None):
    method __len__ (line 20) | def __len__(self):
    method __getitem__ (line 23) | def __getitem__(self, idx):
    method getdata (line 32) | def getdata(self, idx):
  function build_openimage (line 41) | def build_openimage(args, transform):

FILE: dataset/pexels.py
  function build_pexels (line 3) | def build_pexels(args, transform):

FILE: dataset/t2i.py
  class Text2ImgDatasetImg (line 10) | class Text2ImgDatasetImg(Dataset):
    method __init__ (line 11) | def __init__(self, lst_dir, face_lst_dir, transform):
    method __len__ (line 39) | def __len__(self):
    method __getitem__ (line 42) | def __getitem__(self, index):
  class Text2ImgDataset (line 50) | class Text2ImgDataset(Dataset):
    method __init__ (line 51) | def __init__(self, args, transform):
    method __len__ (line 85) | def __len__(self):
    method dummy_data (line 88) | def dummy_data(self):
    method __getitem__ (line 95) | def __getitem__(self, index):
  class Text2ImgDatasetCode (line 138) | class Text2ImgDatasetCode(Dataset):
    method __init__ (line 139) | def __init__(self, args):
  function build_t2i_image (line 145) | def build_t2i_image(args, transform):
  function build_t2i (line 148) | def build_t2i(args, transform):
  function build_t2i_code (line 151) | def build_t2i_code(args):

FILE: dataset/t2i_control.py
  class T2IControlCode (line 36) | class T2IControlCode(Dataset):
    method __init__ (line 37) | def __init__(self, args):
    method __len__ (line 64) | def __len__(self):
    method dummy_data (line 67) | def dummy_data(self):
    method collate_fn (line 74) | def collate_fn(self, examples):
    method __getitem__ (line 104) | def __getitem__(self, index):
  function build_t2i_control_code (line 165) | def build_t2i_control_code(args):

FILE: dataset/utils.py
  function get_reward_model (line 20) | def get_reward_model(task='segmentation', model_path='mmseg::upernet/upe...
  function get_reward_loss (line 44) | def get_reward_loss(predictions, labels, task='segmentation', **args):
  function image_grid (line 65) | def image_grid(imgs, rows, cols):
  function map_color_to_index (line 77) | def map_color_to_index(image, dataset='limingcv/Captioned_ADE20K'):
  function seg_label_transform (line 105) | def seg_label_transform(
  function depth_label_transform (line 142) | def depth_label_transform(
  function edge_label_transform (line 156) | def edge_label_transform(labels, dataset_name):
  function label_transform (line 160) | def label_transform(labels, task, dataset_name, **args):
  function group_random_crop (line 171) | def group_random_crop(images, resolution):
  class ResidualBlock (line 191) | class ResidualBlock(nn.Module):
    method __init__ (line 192) | def __init__(self, in_features):
    method forward (line 206) | def forward(self, x):
  class LineDrawingModel (line 210) | class LineDrawingModel(nn.Module):
    method __init__ (line 211) | def __init__(self, input_nc=3, output_nc=1, n_residual_blocks=3, sigmo...
    method forward (line 258) | def forward(self, x, cond=None):
  class DoubleConvBlock (line 269) | class DoubleConvBlock(torch.nn.Module):
    method __init__ (line 270) | def __init__(self, input_channel, output_channel, layer_number):
    method __call__ (line 278) | def __call__(self, x, down_sampling=False):
  class ControlNetHED_Apache2 (line 288) | class ControlNetHED_Apache2(torch.nn.Module):
    method __init__ (line 289) | def __init__(self):
    method __call__ (line 298) | def __call__(self, x):
  class HEDdetector (line 308) | class HEDdetector(nn.Module):
    method __init__ (line 309) | def __init__(self, model_path):
    method __call__ (line 316) | def __call__(self, input_image):

FILE: demo/app_depth.py
  function randomize_seed_fn (line 5) | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
  function create_demo (line 27) | def create_demo(process):

FILE: demo/app_edge.py
  function randomize_seed_fn (line 5) | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
  function create_demo (line 27) | def create_demo(process):

FILE: demo/model.py
  class Model (line 21) | class Model:
    method __init__ (line 22) | def __init__(self):
    method to (line 34) | def to(self, device):
    method load_vq (line 37) | def load_vq(self):
    method load_gpt (line 48) | def load_gpt(self, condition_type='edge'):
    method load_gpt_weight (line 66) | def load_gpt_weight(self, condition_type='edge'):
    method load_t5 (line 77) | def load_t5(self):
    method process_edge (line 92) | def process_edge(
    method process_depth (line 192) | def process_depth(

FILE: evaluations/ade20k_mIoU.py
  function main (line 9) | def main():

FILE: evaluations/c2i/evaluator.py
  function main (line 27) | def main():
  class InvalidFIDException (line 75) | class InvalidFIDException(Exception):
  class FIDStatistics (line 79) | class FIDStatistics:
    method __init__ (line 80) | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
    method frechet_distance (line 84) | def frechet_distance(self, other, eps=1e-6):
  class Evaluator (line 130) | class Evaluator:
    method __init__ (line 131) | def __init__(
    method warmup (line 147) | def warmup(self):
    method read_activations (line 150) | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndar...
    method compute_activations (line 154) | def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[...
    method read_statistics (line 176) | def read_statistics(
    method compute_statistics (line 186) | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
    method compute_inception_score (line 191) | def compute_inception_score(self, activations: np.ndarray, split_size:...
    method compute_prec_recall (line 206) | def compute_prec_recall(
  class ManifoldEstimator (line 217) | class ManifoldEstimator:
    method __init__ (line 224) | def __init__(
    method warmup (line 253) | def warmup(self):
    method manifold_radii (line 260) | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
    method evaluate (line 295) | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_featu...
    method evaluate_pr (line 337) | def evaluate_pr(
  class DistanceBlock (line 374) | class DistanceBlock:
    method __init__ (line 381) | def __init__(self, session):
    method pairwise_distances (line 405) | def pairwise_distances(self, U, V):
    method less_thans (line 414) | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
  function _batch_pairwise_distances (line 426) | def _batch_pairwise_distances(U, V):
  class NpzArrayReader (line 445) | class NpzArrayReader(ABC):
    method read_batch (line 447) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 451) | def remaining(self) -> int:
    method read_batches (line 454) | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
  class BatchIterator (line 467) | class BatchIterator:
    method __init__ (line 468) | def __init__(self, gen_fn, length):
    method __len__ (line 472) | def __len__(self):
    method __iter__ (line 475) | def __iter__(self):
  class StreamingNpzArrayReader (line 479) | class StreamingNpzArrayReader(NpzArrayReader):
    method __init__ (line 480) | def __init__(self, arr_f, shape, dtype):
    method read_batch (line 486) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 501) | def remaining(self) -> int:
  class MemoryNpzArrayReader (line 505) | class MemoryNpzArrayReader(NpzArrayReader):
    method __init__ (line 506) | def __init__(self, arr):
    method load (line 511) | def load(cls, path: str, arr_name: str):
    method read_batch (line 516) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 524) | def remaining(self) -> int:
  function open_npz_array (line 529) | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
  function _read_bytes (line 546) | def _read_bytes(fp, size, error_template="ran out of data"):
  function _open_npy_file (line 576) | def _open_npy_file(path: str, arr_name: str):
  function _download_inception_model (line 585) | def _download_inception_model():
  function _create_feature_graph (line 598) | def _create_feature_graph(input_batch):
  function _create_softmax_graph (line 615) | def _create_softmax_graph(input_batch):
  function _update_shapes (line 629) | def _update_shapes(pool3):
  function _numpy_partition (line 648) | def _numpy_partition(arr, kth, **kwargs):

FILE: evaluations/canny_f1score.py
  class ImageDataset (line 18) | class ImageDataset(Dataset):
    method __init__ (line 19) | def __init__(self, img_dir, label_dir):
    method __len__ (line 24) | def __len__(self):
    method __getitem__ (line 27) | def __getitem__(self, idx):

FILE: evaluations/clean_fid.py
  function main (line 4) | def main(args):

FILE: evaluations/cocostuff_mIoU.py
  function main (line 9) | def main():

FILE: evaluations/depth_rmse.py
  class ImageDataset (line 15) | class ImageDataset(Dataset):
    method __init__ (line 16) | def __init__(self, img_dir, label_dir):
    method __len__ (line 21) | def __len__(self):
    method __getitem__ (line 24) | def __getitem__(self, idx):

FILE: evaluations/hed_ssim.py
  class ImageDataset (line 17) | class ImageDataset(Dataset):
    method __init__ (line 18) | def __init__(self, img_dir, label_dir):
    method __len__ (line 23) | def __len__(self):
    method __getitem__ (line 26) | def __getitem__(self, idx):

FILE: evaluations/lineart_ssim.py
  class ImageDataset (line 18) | class ImageDataset(Dataset):
    method __init__ (line 19) | def __init__(self, img_dir, label_dir):
    method __len__ (line 24) | def __len__(self):
    method __getitem__ (line 27) | def __getitem__(self, idx):

FILE: evaluations/t2i/evaluation.py
  class CenterCropLongEdge (line 31) | class CenterCropLongEdge(object):
    method __call__ (line 37) | def __call__(self, img):
    method __repr__ (line 40) | def __repr__(self):
  class EvalDataset (line 44) | class EvalDataset(Dataset):
    method __init__ (line 45) | def __init__(self,
    method natural_sort (line 77) | def natural_sort(self, l):
    method load_dataset (line 82) | def load_dataset(self):
    method __len__ (line 100) | def __len__(self):
    method __getitem__ (line 104) | def __getitem__(self, index):
  function tensor2pil (line 117) | def tensor2pil(image: torch.Tensor):
  function compute_clip_score (line 130) | def compute_clip_score(
  function compute_fid (line 180) | def compute_fid(fake_dir: Path, gt_dir: Path,
  function evaluate_model (line 207) | def evaluate_model(opt):

FILE: language/extract_t5_feature.py
  class CustomDataset (line 23) | class CustomDataset(Dataset):
    method __init__ (line 24) | def __init__(self, lst_dir, start, end, caption_key, trunc_caption=Fal...
    method __len__ (line 41) | def __len__(self):
    method __getitem__ (line 44) | def __getitem__(self, index):
  function main (line 53) | def main(args):

FILE: language/t5.py
  class T5Embedder (line 15) | class T5Embedder:
    method __init__ (line 19) | def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=F...
    method get_text_embeddings (line 58) | def get_text_embeddings(self, texts):
    method text_preprocessing (line 81) | def text_preprocessing(self, text):
    method basic_clean (line 91) | def basic_clean(text):
    method clean_caption (line 96) | def clean_caption(self, caption):

FILE: tokenizer/consistencydecoder/cd_demo.py
  function main (line 9) | def main(args):

FILE: tokenizer/consistencydecoder/reconstruction_cd_ddp.py
  class SingleFolderDataset (line 22) | class SingleFolderDataset(Dataset):
    method __init__ (line 23) | def __init__(self, directory, transform=None):
    method __len__ (line 30) | def __len__(self):
    method __getitem__ (line 33) | def __getitem__(self, idx):
  function create_npz_from_sample_folder (line 41) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function center_crop_arr (line 60) | def center_crop_arr(pil_image, image_size):
  function main (line 81) | def main(args):

FILE: tokenizer/tokenizer_image/discriminator.py
  class PatchGANDiscriminator (line 17) | class PatchGANDiscriminator(nn.Module):
    method __init__ (line 21) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
    method _init_weights (line 67) | def _init_weights(self, module):
    method forward (line 74) | def forward(self, input):
  class ActNorm (line 79) | class ActNorm(nn.Module):
    method __init__ (line 80) | def __init__(self, num_features, logdet=False, affine=True,
    method initialize (line 91) | def initialize(self, input):
    method forward (line 112) | def forward(self, input, reverse=False):
    method reverse (line 140) | def reverse(self, output):
  class StyleGANDiscriminator (line 168) | class StyleGANDiscriminator(nn.Module):
    method __init__ (line 169) | def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=...
    method forward (line 203) | def forward(self, x):
  class DiscriminatorBlock (line 212) | class DiscriminatorBlock(nn.Module):
    method __init__ (line 213) | def __init__(self, input_channels, filters, downsample=True):
    method forward (line 229) | def forward(self, x):
  class Blur (line 238) | class Blur(nn.Module):
    method __init__ (line 239) | def __init__(self):
    method forward (line 244) | def forward(self, x):
  function leaky_relu (line 250) | def leaky_relu(p=0.2):
  function exists (line 254) | def exists(val):

FILE: tokenizer/tokenizer_image/discriminator_patchgan.py
  class NLayerDiscriminator (line 8) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 12) | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
    method _init_weights (line 58) | def _init_weights(self, module):
    method forward (line 65) | def forward(self, input):
  class ActNorm (line 70) | class ActNorm(nn.Module):
    method __init__ (line 71) | def __init__(self, num_features, logdet=False, affine=True,
    method initialize (line 82) | def initialize(self, input):
    method forward (line 103) | def forward(self, input, reverse=False):
    method reverse (line 131) | def reverse(self, output):

FILE: tokenizer/tokenizer_image/discriminator_stylegan.py
  class Discriminator (line 13) | class Discriminator(nn.Module):
    method __init__ (line 14) | def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=...
    method forward (line 48) | def forward(self, x):
  class DiscriminatorBlock (line 57) | class DiscriminatorBlock(nn.Module):
    method __init__ (line 58) | def __init__(self, input_channels, filters, downsample=True):
    method forward (line 74) | def forward(self, x):
  class Blur (line 84) | class Blur(nn.Module):
    method __init__ (line 85) | def __init__(self):
    method forward (line 90) | def forward(self, x):
  function leaky_relu (line 96) | def leaky_relu(p=0.2):
  function exists (line 100) | def exists(val):

FILE: tokenizer/tokenizer_image/lpips.py
  function download (line 24) | def download(url, local_path, chunk_size=1024):
  function md5_hash (line 36) | def md5_hash(path):
  function get_ckpt_path (line 42) | def get_ckpt_path(name, root, check=False):
  class LPIPS (line 53) | class LPIPS(nn.Module):
    method __init__ (line 55) | def __init__(self, use_dropout=True):
    method load_from_pretrained (line 69) | def load_from_pretrained(self, name="vgg_lpips"):
    method from_pretrained (line 75) | def from_pretrained(cls, name="vgg_lpips"):
    method forward (line 83) | def forward(self, input, target):
  class ScalingLayer (line 99) | class ScalingLayer(nn.Module):
    method __init__ (line 100) | def __init__(self):
    method forward (line 105) | def forward(self, inp):
  class NetLinLayer (line 109) | class NetLinLayer(nn.Module):
    method __init__ (line 111) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
  class vgg16 (line 118) | class vgg16(torch.nn.Module):
    method __init__ (line 119) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 142) | def forward(self, X):
  function normalize_tensor (line 158) | def normalize_tensor(x,eps=1e-10):
  function spatial_average (line 163) | def spatial_average(x, keepdim=True):

FILE: tokenizer/tokenizer_image/reconstruction_vq_ddp.py
  function create_npz_from_sample_folder (line 24) | def create_npz_from_sample_folder(sample_dir, num=50000):
  function main (line 42) | def main(args):

FILE: tokenizer/tokenizer_image/vq_demo.py
  function main (line 13) | def main(args):

FILE: tokenizer/tokenizer_image/vq_loss.py
  function hinge_d_loss (line 14) | def hinge_d_loss(logits_real, logits_fake):
  function vanilla_d_loss (line 21) | def vanilla_d_loss(logits_real, logits_fake):
  function non_saturating_d_loss (line 28) | def non_saturating_d_loss(logits_real, logits_fake):
  function hinge_gen_loss (line 35) | def hinge_gen_loss(logit_fake):
  function non_saturating_gen_loss (line 39) | def non_saturating_gen_loss(logit_fake):
  function adopt_weight (line 43) | def adopt_weight(weight, global_step, threshold=0, value=0.):
  class VQLoss (line 49) | class VQLoss(nn.Module):
    method __init__ (line 50) | def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_ty...
    method calculate_adaptive_weight (line 109) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
    method forward (line 117) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_id...

FILE: tokenizer/tokenizer_image/vq_model.py
  class ModelArgs (line 13) | class ModelArgs:
  class VQModel (line 28) | class VQModel(nn.Module):
    method __init__ (line 29) | def __init__(self, config: ModelArgs):
    method encode (line 41) | def encode(self, x):
    method decode (line 48) | def decode(self, quant):
    method decode_code (line 53) | def decode_code(self, code_b, shape=None, channel_first=True):
    method forward (line 58) | def forward(self, input):
  class Encoder (line 65) | class Encoder(nn.Module):
    method __init__ (line 66) | def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res...
    method forward (line 106) | def forward(self, x):
  class Decoder (line 129) | class Decoder(nn.Module):
    method __init__ (line 130) | def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_re...
    method last_layer (line 171) | def last_layer(self):
    method forward (line 174) | def forward(self, z):
  class VectorQuantizer (line 198) | class VectorQuantizer(nn.Module):
    method __init__ (line 199) | def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show...
    method forward (line 216) | def forward(self, z):
    method get_codebook_entry (line 262) | def get_codebook_entry(self, indices, shape=None, channel_first=True):
  class ResnetBlock (line 280) | class ResnetBlock(nn.Module):
    method __init__ (line 281) | def __init__(self, in_channels, out_channels=None, conv_shortcut=False...
    method forward (line 300) | def forward(self, x):
  class AttnBlock (line 318) | class AttnBlock(nn.Module):
    method __init__ (line 319) | def __init__(self, in_channels, norm_type='group'):
    method forward (line 328) | def forward(self, x):
  function nonlinearity (line 355) | def nonlinearity(x):
  function Normalize (line 360) | def Normalize(in_channels, norm_type='group'):
  class Upsample (line 368) | class Upsample(nn.Module):
    method __init__ (line 369) | def __init__(self, in_channels, with_conv):
    method forward (line 375) | def forward(self, x):
  class Downsample (line 382) | class Downsample(nn.Module):
    method __init__ (line 383) | def __init__(self, in_channels, with_conv):
    method forward (line 390) | def forward(self, x):
  function compute_entropy_loss (line 400) | def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
  function VQ_8 (line 419) | def VQ_8(**kwargs):
  function VQ_16 (line 422) | def VQ_16(**kwargs):

FILE: tokenizer/tokenizer_image/vq_model_hf.py
  class VQModelHF (line 5) | class VQModelHF(VQModel, PyTorchModelHubMixin, repo_url="https://github....
  function VQ_8 (line 11) | def VQ_8(**kwargs):
  function VQ_16 (line 14) | def VQ_16(**kwargs):

FILE: tokenizer/tokenizer_image/vq_train.py
  function main (line 37) | def main(args):

FILE: tokenizer/vae/reconstruction_vae_ddp.py
  class SingleFolderDataset (line 22) | class SingleFolderDataset(Dataset):
    method __init__ (line 23) | def __init__(self, directory, transform=None):
    method __len__ (line 30) | def __len__(self):
    method __getitem__ (line 33) | def __getitem__(self, idx):
  function create_npz_from_sample_folder (line 41) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function center_crop_arr (line 60) | def center_crop_arr(pil_image, image_size):
  function main (line 81) | def main(args):

FILE: tokenizer/vae/sd_vae_demo.py
  function main (line 9) | def main(args):

FILE: tokenizer/validation/val_ddp.py
  class SingleFolderDataset (line 17) | class SingleFolderDataset(Dataset):
    method __init__ (line 18) | def __init__(self, directory, transform=None):
    method __len__ (line 25) | def __len__(self):
    method __getitem__ (line 28) | def __getitem__(self, idx):
  function create_npz_from_sample_folder (line 36) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function center_crop_arr (line 55) | def center_crop_arr(pil_image, image_size):
  function main (line 76) | def main(args):

FILE: tokenizer/vqgan/layer.py
  function nonlinearity (line 8) | def nonlinearity(x):
  function Normalize (line 13) | def Normalize(in_channels):
  class Upsample (line 17) | class Upsample(nn.Module):
    method __init__ (line 18) | def __init__(self, in_channels, with_conv):
    method forward (line 28) | def forward(self, x):
  class Downsample (line 35) | class Downsample(nn.Module):
    method __init__ (line 36) | def __init__(self, in_channels, with_conv):
    method forward (line 47) | def forward(self, x):
  class ResnetBlock (line 57) | class ResnetBlock(nn.Module):
    method __init__ (line 58) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 96) | def forward(self, x, temb):
  class AttnBlock (line 119) | class AttnBlock(nn.Module):
    method __init__ (line 120) | def __init__(self, in_channels):
    method forward (line 147) | def forward(self, x):
  class Encoder (line 175) | class Encoder(nn.Module):
    method __init__ (line 176) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 239) | def forward(self, x):
  class Decoder (line 269) | class Decoder(nn.Module):
    method __init__ (line 270) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 339) | def forward(self, z):

FILE: tokenizer/vqgan/model.py
  class VQModel (line 24) | class VQModel(nn.Module):
    method __init__ (line 25) | def __init__(self,
    method init_from_ckpt (line 55) | def init_from_ckpt(self, path, ignore_keys=list(), logging=True):
    method encode (line 69) | def encode(self, x):
    method decode (line 75) | def decode(self, quant):
    method decode_code (line 80) | def decode_code(self, code_b, shape, channel_first=True):
    method forward (line 85) | def forward(self, input):

FILE: tokenizer/vqgan/quantize.py
  class VectorQuantizer (line 9) | class VectorQuantizer(nn.Module):
    method __init__ (line 25) | def __init__(self, n_e, e_dim, beta):
    method forward (line 34) | def forward(self, z):
    method get_codebook_entry (line 92) | def get_codebook_entry(self, indices, shape):
  class VectorQuantizer2 (line 110) | class VectorQuantizer2(nn.Module):
    method __init__ (line 118) | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
    method remap_to_used (line 144) | def remap_to_used(self, inds):
    method unmap_to_all (line 158) | def unmap_to_all(self, inds):
    method forward (line 168) | def forward(self, z, temp=None, rescale_logits=False, return_logits=Fa...
    method get_codebook_entry (line 211) | def get_codebook_entry(self, indices, shape, channel_first=True):

FILE: tokenizer/vqgan/reconstruction_vqgan_ddp.py
  class SingleFolderDataset (line 24) | class SingleFolderDataset(Dataset):
    method __init__ (line 25) | def __init__(self, directory, transform=None):
    method __len__ (line 32) | def __len__(self):
    method __getitem__ (line 35) | def __getitem__(self, idx):
  function create_npz_from_sample_folder (line 43) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function center_crop_arr (line 62) | def center_crop_arr(pil_image, image_size):
  function main (line 83) | def main(args):

FILE: tokenizer/vqgan/taming_vqgan_demo.py
  function main (line 17) | def main(args):

FILE: tools/check_image_codes.py
  function main (line 9) | def main(args):

FILE: tools/draw_figure.py
  function fid_scaling_law_no_cfg (line 6) | def fid_scaling_law_no_cfg():
  function fid_scaling_law_cfg (line 43) | def fid_scaling_law_cfg():
  function sample_topk (line 80) | def sample_topk():
  function sample_cfg (line 108) | def sample_cfg():

FILE: tools/openimage_json.py
  function check_image (line 11) | def check_image(image_path):
  function check_image_path (line 20) | def check_image_path(image_info):
  function load_image_path (line 29) | def load_image_path(image_info):
  function main (line 44) | def main(args):

FILE: tools/push_gpt_to_hf.py
  function main (line 13) | def main(args):

FILE: tools/push_vae_to_hf.py
  function load_model (line 17) | def load_model(args):

FILE: utils/data.py
  function center_crop_arr (line 4) | def center_crop_arr(pil_image, image_size):

FILE: utils/deepspeed.py
  function create_deepspeed_config (line 1) | def create_deepspeed_config(args):

FILE: utils/distributed.py
  function setup_for_distributed (line 6) | def setup_for_distributed(is_master):
  function init_distributed_mode (line 20) | def init_distributed_mode(args):

FILE: utils/drop_path.py
  function drop_path (line 4) | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by...
  class DropPath (line 24) | class DropPath(torch.nn.Module):
    method __init__ (line 27) | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
    method forward (line 32) | def forward(self, x):
    method extra_repr (line 35) | def extra_repr(self):

FILE: utils/ema.py
  function update_ema (line 5) | def update_ema(ema_model, model, decay=0.9999):
  function requires_grad (line 17) | def requires_grad(model, flag=True):

FILE: utils/logger.py
  function create_logger (line 4) | def create_logger(logging_dir):

FILE: utils/video.py
  function shift_dim (line 8) | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
  function view_range (line 38) | def view_range(x, i, j, shape):
  function tensor_slice (line 57) | def tensor_slice(x, begin, size):
  function save_video_grid (line 67) | def save_video_grid(video, fname, nrow=None, fps=5):
  function save_gif_grid (line 89) | def save_gif_grid(video, file_name, nrow=None, fps=5):
Condensed preview — 184 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (2,840K chars).
[
  {
    "path": ".gitignore",
    "chars": 3139,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 15232,
    "preview": "<div align =\"center\">\n<img src=\"./assets/logo.jpeg\" width=\"20%\">\n<h1> ControlAR </h1>\n<h3> Controllable Image Generation"
  },
  {
    "path": "autoregressive/models/README.md",
    "chars": 233,
    "preview": "Download the vit weight first \n\nViT-small: https://huggingface.co/WinKawaks/vit-small-patch16-224 \\\nDinov2-small: https:"
  },
  {
    "path": "autoregressive/models/dinov2_adapter.py",
    "chars": 1391,
    "preview": "from transformers import AutoImageProcessor, AutoModel\nfrom PIL import Image\nimport requests\nimport torch\nimport torch.n"
  },
  {
    "path": "autoregressive/models/generate.py",
    "chars": 9546,
    "preview": "# Modified from:\n#   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py\n#   DiT:      https://gith"
  },
  {
    "path": "autoregressive/models/gpt.py",
    "chars": 24127,
    "preview": "# Modified from:\n#   VQGAN:    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/min"
  },
  {
    "path": "autoregressive/models/gpt_t2i.py",
    "chars": 24866,
    "preview": "# Modified from:\n#   VQGAN:    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/min"
  },
  {
    "path": "autoregressive/models/vit_adapter.py",
    "chars": 770,
    "preview": "from transformers import AutoImageProcessor, AutoModel\nfrom PIL import Image\nimport requests\nimport torch\nimport torch.n"
  },
  {
    "path": "autoregressive/sample/sample_c2i.py",
    "chars": 7015,
    "preview": "# Modified from:\n#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample.py\nimport torch\ntorch.backends.cuda."
  },
  {
    "path": "autoregressive/sample/sample_c2i_ddp.py",
    "chars": 8788,
    "preview": "# Modified from:\n#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py\nimport torch\ntorch.backends.c"
  },
  {
    "path": "autoregressive/sample/sample_t2i.py",
    "chars": 11153,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\ntorch.set_float32_matmu"
  },
  {
    "path": "autoregressive/sample/sample_t2i_MR.py",
    "chars": 11656,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\ntorch.set_float32_matmu"
  },
  {
    "path": "autoregressive/sample/sample_t2i_ddp.py",
    "chars": 10676,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\ntorch.set_float32_matmu"
  },
  {
    "path": "autoregressive/serve/README.md",
    "chars": 2474,
    "preview": "## serving by vLLM\n\n### Install\n```\npip install vllm==0.4.1\n```\n\n### Comparison (A100)\n\nMethod | params | baseline(s) | "
  },
  {
    "path": "autoregressive/serve/fake_json/GPT-3B.json",
    "chars": 653,
    "preview": "{\n  \"_name_or_path\": \"facebook/opt-125m\",\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"relu\",\n  \"architectures"
  },
  {
    "path": "autoregressive/serve/fake_json/GPT-B.json",
    "chars": 652,
    "preview": "{\n  \"_name_or_path\": \"facebook/opt-125m\",\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"relu\",\n  \"architectures"
  },
  {
    "path": "autoregressive/serve/fake_json/GPT-L.json",
    "chars": 653,
    "preview": "{\n  \"_name_or_path\": \"facebook/opt-125m\",\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"relu\",\n  \"architectures"
  },
  {
    "path": "autoregressive/serve/fake_json/GPT-XL.json",
    "chars": 653,
    "preview": "{\n  \"_name_or_path\": \"facebook/opt-125m\",\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"relu\",\n  \"architectures"
  },
  {
    "path": "autoregressive/serve/fake_json/GPT-XXL.json",
    "chars": 653,
    "preview": "{\n  \"_name_or_path\": \"facebook/opt-125m\",\n  \"activation_dropout\": 0.0,\n  \"activation_function\": \"relu\",\n  \"architectures"
  },
  {
    "path": "autoregressive/serve/gpt_model.py",
    "chars": 17073,
    "preview": "from dataclasses import dataclass\nfrom typing import Optional, List\n\nimport torch\nimport torch.nn as nn\n\nfrom vllm.model"
  },
  {
    "path": "autoregressive/serve/gpu_executor.py",
    "chars": 7809,
    "preview": "from typing import Dict, List, Set, Tuple, Optional, Set\nimport argparse\n\nfrom vllm.config import (CacheConfig, DeviceCo"
  },
  {
    "path": "autoregressive/serve/llm.py",
    "chars": 12277,
    "preview": "# Modified from:\n#   vLLM:    https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py\nfrom typing import "
  },
  {
    "path": "autoregressive/serve/llm_engine.py",
    "chars": 28754,
    "preview": "# Modified from:\n#   vLLM:    https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py\nimport time\nfrom "
  },
  {
    "path": "autoregressive/serve/model_runner.py",
    "chars": 52906,
    "preview": "import contextlib\nimport time\nfrom enum import IntEnum\nfrom typing import Dict, List, NamedTuple, Optional, Set, Tuple\n\n"
  },
  {
    "path": "autoregressive/serve/sample_c2i.py",
    "chars": 4736,
    "preview": "import time\nimport argparse\nimport torch\nfrom torchvision.utils import save_image\nimport sys\nsys.path.append('/data/zong"
  },
  {
    "path": "autoregressive/serve/sampler.py",
    "chars": 37923,
    "preview": "\"\"\"A layer that samples the next tokens from the model's outputs.\"\"\"\nimport itertools\nfrom typing import Dict, List, Opt"
  },
  {
    "path": "autoregressive/serve/worker.py",
    "chars": 14776,
    "preview": "\"\"\"A GPU worker class.\"\"\"\nimport gc\nimport os\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nimport torch\nimp"
  },
  {
    "path": "autoregressive/test/metric.py",
    "chars": 2848,
    "preview": "import numpy as np\nfrom skimage.metrics import structural_similarity as ssim\nfrom sklearn.metrics import f1_score\nfrom t"
  },
  {
    "path": "autoregressive/test/test_c2i.py",
    "chars": 11518,
    "preview": "# Modified from:\n#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample.py\nimport torch\ntorch.backends.cuda."
  },
  {
    "path": "autoregressive/test/test_ssim.py",
    "chars": 666,
    "preview": "\nimport torch\nfrom torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure\nimport torchvision.transforms as"
  },
  {
    "path": "autoregressive/test/test_t2i.py",
    "chars": 13097,
    "preview": "# Modified from:\n#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample.py\nimport warnings\nwarnings.filterwa"
  },
  {
    "path": "autoregressive/train/extract_codes_c2i.py",
    "chars": 5728,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\n# import os\n# os.e"
  },
  {
    "path": "autoregressive/train/extract_codes_t2i.py",
    "chars": 5607,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\nimport torch\ntorch"
  },
  {
    "path": "autoregressive/train/extract_file_ade.py",
    "chars": 9628,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\nimport os\n# os.env"
  },
  {
    "path": "autoregressive/train/extract_file_cocostuff.py",
    "chars": 9580,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\nimport os\n# os.env"
  },
  {
    "path": "autoregressive/train/extract_file_imagenet.py",
    "chars": 7642,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\nimport os\n# os.env"
  },
  {
    "path": "autoregressive/train/extract_file_multigen.py",
    "chars": 9210,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py\n\nimport warnings\nw"
  },
  {
    "path": "autoregressive/train/train_c2i.py",
    "chars": 14777,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py\n#   nanoGPT: https://github.c"
  },
  {
    "path": "autoregressive/train/train_c2i_canny.py",
    "chars": 15944,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py\n#   nanoGPT: https://github.c"
  },
  {
    "path": "autoregressive/train/train_c2i_depth.py",
    "chars": 16177,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py\n#   nanoGPT: https://github.c"
  },
  {
    "path": "autoregressive/train/train_c2i_fsdp.py",
    "chars": 18123,
    "preview": "# Modified from:\n#   Large-DiT: https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py\nimp"
  },
  {
    "path": "autoregressive/train/train_t2i.py",
    "chars": 13133,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_canny.py",
    "chars": 14196,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_depth.py",
    "chars": 14362,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_depth_multiscale.py",
    "chars": 17803,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_hed.py",
    "chars": 14567,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_hed_multiscale.py",
    "chars": 17393,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_lineart.py",
    "chars": 14601,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_lineart_multiscale.py",
    "chars": 17560,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_seg.py",
    "chars": 14342,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "autoregressive/train/train_t2i_seg_multiscale.py",
    "chars": 17293,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT\n#   nanoGPT: https://github.com/karpathy/nanoGPT"
  },
  {
    "path": "condition/README.md",
    "chars": 817,
    "preview": "Prepare the preprocessing model\n\nHed: https://huggingface.co/lllyasviel/Annotators/blob/main/ControlNetHED.pth\\\nLineart:"
  },
  {
    "path": "condition/canny.py",
    "chars": 788,
    "preview": "import cv2\nimport torch\nimport numpy as np\n\n\nclass CannyDetector:\n    def __call__(self, img, low_threshold=100, high_th"
  },
  {
    "path": "condition/depth.py",
    "chars": 1691,
    "preview": "from controlnet_aux import LineartDetector\nimport torch\nimport cv2\nimport numpy as np\nfrom transformers import DPTImageP"
  },
  {
    "path": "condition/hed.py",
    "chars": 5535,
    "preview": "# This is an improved version and model of HED edge detection with Apache License, Version 2.0.\n# Please use this implem"
  },
  {
    "path": "condition/lineart.py",
    "chars": 3378,
    "preview": "from controlnet_aux import LineartDetector\nimport torch\nimport cv2\nimport numpy as np\nimport torch.nn as nn\n\n\nnorm_layer"
  },
  {
    "path": "condition/midas/depth.py",
    "chars": 7585,
    "preview": "# Midas Depth Estimation\n# From https://github.com/isl-org/MiDaS\n# MIT LICENSE\n\nimport cv2\nimport numpy as np\nimport tor"
  },
  {
    "path": "condition/midas/midas/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "condition/midas/midas/base_model.py",
    "chars": 366,
    "preview": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n\n        Arg"
  },
  {
    "path": "condition/midas/midas/blocks.py",
    "chars": 9240,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .vit import (\n    _make_pretrained_vitb_rn50_384,\n    _make_pretrained_vitl16_3"
  },
  {
    "path": "condition/midas/midas/dpt_depth.py",
    "chars": 3152,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .base_model import BaseModel\nfrom .blocks impor"
  },
  {
    "path": "condition/midas/midas/midas_net.py",
    "chars": 2708,
    "preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
  },
  {
    "path": "condition/midas/midas/midas_net_custom.py",
    "chars": 5207,
    "preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
  },
  {
    "path": "condition/midas/midas/transforms.py",
    "chars": 7868,
    "preview": "import numpy as np\nimport cv2\nimport math\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):"
  },
  {
    "path": "condition/midas/midas/vit.py",
    "chars": 14624,
    "preview": "import torch\nimport torch.nn as nn\nimport timm\nimport types\nimport math\nimport torch.nn.functional as F\n\n\nclass Slice(nn"
  },
  {
    "path": "condition/utils.py",
    "chars": 979,
    "preview": "import numpy as np\nimport cv2\nimport os\n\n\nannotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')\n\n\ndef "
  },
  {
    "path": "create_npz.py",
    "chars": 968,
    "preview": "from tqdm import tqdm\nimport os\nfrom PIL import Image\nimport numpy as np\nimport argparse\n\n\ndef create_npz_from_sample_fo"
  },
  {
    "path": "dataset/augmentation.py",
    "chars": 2077,
    "preview": "# from https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_d"
  },
  {
    "path": "dataset/build.py",
    "chars": 1155,
    "preview": "from dataset.imagenet import build_imagenet, build_imagenet_code\nfrom dataset.coco import build_coco\nfrom dataset.openim"
  },
  {
    "path": "dataset/coco.py",
    "chars": 853,
    "preview": "import os\nimport torch\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\n\nclass SingleFolderDataset(Dataset):\n"
  },
  {
    "path": "dataset/imagenet.py",
    "chars": 5938,
    "preview": "import torch\nimport numpy as np\nimport os\nfrom torch.utils.data import Dataset\nfrom torchvision.datasets import ImageFol"
  },
  {
    "path": "dataset/openimage.py",
    "chars": 1309,
    "preview": "import os\nimport json\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nfrom torch.utils.data import Dataset\n\n\nclas"
  },
  {
    "path": "dataset/pexels.py",
    "chars": 140,
    "preview": "from torchvision.datasets import ImageFolder\n\ndef build_pexels(args, transform):\n    return ImageFolder(args.data_path, "
  },
  {
    "path": "dataset/t2i.py",
    "chars": 6076,
    "preview": "import os\nimport json\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import Dataset\nfrom PIL import Image \n\n\ncla"
  },
  {
    "path": "dataset/t2i_control.py",
    "chars": 7857,
    "preview": "from PIL import PngImagePlugin\nMaximumDecompressedSize = 1024\nMegaByte = 2**20\nPngImagePlugin.MAX_TEXT_CHUNK = MaximumDe"
  },
  {
    "path": "dataset/utils.py",
    "chars": 11952,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nimport torchvision.transforms.functional as F\n\nfrom PIL import Ima"
  },
  {
    "path": "demo/app.py",
    "chars": 1493,
    "preview": "import os\nimport gradio as gr\nfrom .model import Model\nfrom huggingface_hub import hf_hub_download\nfrom app_canny import"
  },
  {
    "path": "demo/app_depth.py",
    "chars": 4846,
    "preview": "import gradio as gr\r\nimport random\r\n\r\n\r\ndef randomize_seed_fn(seed: int, randomize_seed: bool) -> int:\r\n    if randomize"
  },
  {
    "path": "demo/app_edge.py",
    "chars": 5460,
    "preview": "import gradio as gr\r\nimport random\r\n\r\n\r\ndef randomize_seed_fn(seed: int, randomize_seed: bool) -> int:\r\n    if randomize"
  },
  {
    "path": "demo/model.py",
    "chars": 10478,
    "preview": "import gc\nimport spaces\nfrom safetensors.torch import load_file\nfrom autoregressive.models.gpt_t2i import GPT_models\nfro"
  },
  {
    "path": "evaluations/ade20k_mIoU.py",
    "chars": 2637,
    "preview": "import os\nimport numpy as np\nfrom mmseg.apis import init_model, inference_model, show_result_pyplot#, inference_segmento"
  },
  {
    "path": "evaluations/c2i/README.md",
    "chars": 7477,
    "preview": "# Evaluations from [OpenAI](https://github.com/openai/guided-diffusion/tree/main/evaluations)\n\nTo compare different gene"
  },
  {
    "path": "evaluations/c2i/evaluator.py",
    "chars": 25431,
    "preview": "import argparse\nimport io\nimport os\nimport random\nimport warnings\nimport zipfile\nfrom abc import ABC, abstractmethod\nfro"
  },
  {
    "path": "evaluations/canny_f1score.py",
    "chars": 2204,
    "preview": "import matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\n"
  },
  {
    "path": "evaluations/clean_fid.py",
    "chars": 507,
    "preview": "from cleanfid import fid\nimport argparse\n\ndef main(args):\n    real_data_path = args.val_images\n    gen_data_path = args."
  },
  {
    "path": "evaluations/cocostuff_mIoU.py",
    "chars": 2706,
    "preview": "import os\nimport numpy as np\nfrom mmseg.apis import init_model, inference_model, show_result_pyplot#, inference_segmento"
  },
  {
    "path": "evaluations/depth_rmse.py",
    "chars": 2256,
    "preview": "import matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\n"
  },
  {
    "path": "evaluations/hed_ssim.py",
    "chars": 1927,
    "preview": "import matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\n"
  },
  {
    "path": "evaluations/lineart_ssim.py",
    "chars": 2052,
    "preview": "import matplotlib.pyplot as plt\nfrom tqdm import tqdm\nfrom transformers import DPTImageProcessor, DPTForDepthEstimation\n"
  },
  {
    "path": "evaluations/t2i/PartiPrompts.tsv",
    "chars": 123072,
    "preview": "Prompt\tCategory\tChallenge\tNote\nbond\tAbstract\tBasic\tBiology-inspired concepts with multiple meanings\nelement\tAbstract\tBas"
  },
  {
    "path": "evaluations/t2i/README.md",
    "chars": 453,
    "preview": "# Evaluations from [GigaGAN](https://github.com/mingukkang/GigaGAN/tree/main/evaluation)\n\n```\npip install git+https://gi"
  },
  {
    "path": "evaluations/t2i/coco_captions.csv",
    "chars": 1595024,
    "preview": "Prompt\nThis wire metal rack holds several pairs of shoes and sandals\nA motorcycle parked in a parking space next to anot"
  },
  {
    "path": "evaluations/t2i/evaluation.py",
    "chars": 10054,
    "preview": "# Modified from:\n#   GigaGAN: https://github.com/mingukkang/GigaGAN\nimport os\nimport torch\nimport numpy as np\nimport re\n"
  },
  {
    "path": "language/README.md",
    "chars": 375,
    "preview": "## Language models for text-conditional image generation\n\n### Requirements\n```\npip install ftfy\npip install transformers"
  },
  {
    "path": "language/extract_t5_feature.py",
    "chars": 5041,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.distribute"
  },
  {
    "path": "language/t5.py",
    "chars": 8699,
    "preview": "# Modified from:\n#   PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py\nimport os\nim"
  },
  {
    "path": "requirements.txt",
    "chars": 260,
    "preview": "torchvision==0.16.2\nopencv-python==4.9.0.80\nmatplotlib==3.9.0\nnumpy==1.26.4\neinops\ndatasets\ntensorflow==2.16.1\nscikit-le"
  },
  {
    "path": "scripts/autoregressive/extract_codes_c2i.sh",
    "chars": 147,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=2 --node_rank=0 \\\n--master_port=12335 \\\nautoregressive/train"
  },
  {
    "path": "scripts/autoregressive/extract_file_ade.sh",
    "chars": 146,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12336 \\\nautoregressive/train"
  },
  {
    "path": "scripts/autoregressive/extract_file_cocostuff.sh",
    "chars": 152,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12336 \\\nautoregressive/train"
  },
  {
    "path": "scripts/autoregressive/extract_file_imagenet.sh",
    "chars": 151,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12336 \\\nautoregressive/train"
  },
  {
    "path": "scripts/autoregressive/extract_file_multigen.sh",
    "chars": 167,
    "preview": "# !/bin/bash\n\n# sleep 21600\n\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12336 \\\nauto"
  },
  {
    "path": "scripts/autoregressive/sample_c2i.sh",
    "chars": 194,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_port=12346 \\\nautoregressive/sampl"
  },
  {
    "path": "scripts/autoregressive/sample_t2i_coco.sh",
    "chars": 271,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12346 \\\nautoregressive/sampl"
  },
  {
    "path": "scripts/autoregressive/sample_t2i_parti.sh",
    "chars": 271,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12347 \\\nautoregressive/sampl"
  },
  {
    "path": "scripts/autoregressive/test_c2i.sh",
    "chars": 139,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12349 \\\nautoregressive/test/"
  },
  {
    "path": "scripts/autoregressive/test_t2i.sh",
    "chars": 139,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12349 \\\nautoregressive/test/"
  },
  {
    "path": "scripts/autoregressive/train_c2i.sh",
    "chars": 163,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_addr=127.0.0.1 --master_port=1234"
  },
  {
    "path": "scripts/autoregressive/train_c2i_canny.sh",
    "chars": 183,
    "preview": "# !/bin/bash\n# sleep 43200\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_addr=127.0.0.1 --ma"
  },
  {
    "path": "scripts/autoregressive/train_c2i_depth.sh",
    "chars": 183,
    "preview": "# !/bin/bash\n# sleep 39600\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_addr=127.0.0.1 --ma"
  },
  {
    "path": "scripts/autoregressive/train_c2i_fsdp.sh",
    "chars": 207,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \\\n--master_addr"
  },
  {
    "path": "scripts/autoregressive/train_t2i_canny.sh",
    "chars": 409,
    "preview": "# !/bin/bash\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master"
  },
  {
    "path": "scripts/autoregressive/train_t2i_depth.sh",
    "chars": 423,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_depth_multiscale.sh",
    "chars": 471,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_hed.sh",
    "chars": 443,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_hed_multiscale.sh",
    "chars": 469,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_lineart.sh",
    "chars": 425,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_lineart_multiscale.sh",
    "chars": 473,
    "preview": "# !/bin/bash\n# sleep 36000\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_ran"
  },
  {
    "path": "scripts/autoregressive/train_t2i_seg.sh",
    "chars": 458,
    "preview": "# !/bin/bash\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master"
  },
  {
    "path": "scripts/autoregressive/train_t2i_seg_multiscale.sh",
    "chars": 524,
    "preview": "# !/bin/bash\nset -x\nexport TOKENIZERS_PARALLELISM=true\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master"
  },
  {
    "path": "scripts/autoregressive/train_t2i_stage1.sh",
    "chars": 374,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \\\n--master_addr"
  },
  {
    "path": "scripts/autoregressive/train_t2i_stage2.sh",
    "chars": 452,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \\\n--master_addr"
  },
  {
    "path": "scripts/language/extract_flan_t5_feat_laion_coco_stage1.sh",
    "chars": 242,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12337 \\\nlanguage/extract_t5_"
  },
  {
    "path": "scripts/language/extract_flan_t5_feat_stage2.sh",
    "chars": 231,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12337 \\\nlanguage/extract_t5_"
  },
  {
    "path": "scripts/language/extract_flan_t5_feat_trunc_stage2.sh",
    "chars": 255,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12337 \\\nlanguage/extract_t5_"
  },
  {
    "path": "scripts/tokenizer/reconstruction_consistency_decoder.sh",
    "chars": 160,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12344 \\\ntokenizer/consistenc"
  },
  {
    "path": "scripts/tokenizer/reconstruction_vae.sh",
    "chars": 146,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12344 \\\ntokenizer/vae/recons"
  },
  {
    "path": "scripts/tokenizer/reconstruction_vq.sh",
    "chars": 157,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=1 --node_rank=0 \\\n--master_port=12344 \\\ntokenizer/tokenizer_"
  },
  {
    "path": "scripts/tokenizer/reconstruction_vqgan.sh",
    "chars": 150,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=8 --node_rank=0 \\\n--master_port=12344 \\\ntokenizer/vqgan/reco"
  },
  {
    "path": "scripts/tokenizer/train_vq.sh",
    "chars": 166,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_addr=127.0.0.1 --master_port=1234"
  },
  {
    "path": "scripts/tokenizer/train_vq_finetune.sh",
    "chars": 291,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_addr=127.0.0.1 --master_port=1234"
  },
  {
    "path": "scripts/tokenizer/train_vq_finetune_continue.sh",
    "chars": 384,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \\\n--master_addr"
  },
  {
    "path": "scripts/tokenizer/val.sh",
    "chars": 138,
    "preview": "# !/bin/bash\nset -x\n\ntorchrun \\\n--nnodes=1 --nproc_per_node=4 --node_rank=0 \\\n--master_port=12343 \\\ntokenizer/validation"
  },
  {
    "path": "tokenizer/consistencydecoder/README.md",
    "chars": 189,
    "preview": "## Consistency Decoder from OpenAI\n\n### install\n```\npip install diffusers\npip install accelerate\n```\n\n### demo\n```\ncd ${"
  },
  {
    "path": "tokenizer/consistencydecoder/cd_demo.py",
    "chars": 1972,
    "preview": "import argparse\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\nfrom diffusers imp"
  },
  {
    "path": "tokenizer/consistencydecoder/reconstruction_cd_ddp.py",
    "chars": 8272,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.distribute"
  },
  {
    "path": "tokenizer/tokenizer_image/discriminator.py",
    "chars": 8693,
    "preview": "# Modified from:\n#   taming-transformers:  https://github.com/CompVis/taming-transformers\n#   stylegan2-pytorch:    http"
  },
  {
    "path": "tokenizer/tokenizer_image/discriminator_patchgan.py",
    "chars": 5277,
    "preview": "# Modified from:\n#   taming-transformers:  https://github.com/CompVis/taming-transformers\nimport functools\nimport torch\n"
  },
  {
    "path": "tokenizer/tokenizer_image/discriminator_stylegan.py",
    "chars": 3083,
    "preview": "# Modified from:\n#   stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/st"
  },
  {
    "path": "tokenizer/tokenizer_image/lpips.py",
    "chars": 6208,
    "preview": "\"\"\"Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models\"\"\"\n\nimport os, hashlib\nimpor"
  },
  {
    "path": "tokenizer/tokenizer_image/reconstruction_vq_ddp.py",
    "chars": 8666,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.nn.functio"
  },
  {
    "path": "tokenizer/tokenizer_image/vq_demo.py",
    "chars": 3312,
    "preview": "import torch\nimport torch.nn.functional as F\n\nimport os\nimport argparse\nimport numpy as np\nfrom PIL import Image\n\nfrom t"
  },
  {
    "path": "tokenizer/tokenizer_image/vq_loss.py",
    "chars": 7518,
    "preview": "# Modified from:\n#   taming-transformers:  https://github.com/CompVis/taming-transformers\n#   muse-maskgit-pytorch: http"
  },
  {
    "path": "tokenizer/tokenizer_image/vq_model.py",
    "chars": 16012,
    "preview": "# Modified from:\n#   taming-transformers: https://github.com/CompVis/taming-transformers\n#   maskgit: https://github.com"
  },
  {
    "path": "tokenizer/tokenizer_image/vq_model_hf.py",
    "chars": 828,
    "preview": "from huggingface_hub import PyTorchModelHubMixin\n\nfrom tokenizer.tokenizer_image.vq_model import ModelArgs, VQModel\n\ncla"
  },
  {
    "path": "tokenizer/tokenizer_image/vq_train.py",
    "chars": 16438,
    "preview": "# Modified from:\n#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py\n#   nanoGPT: https://github.c"
  },
  {
    "path": "tokenizer/vae/README.md",
    "chars": 179,
    "preview": "## VAE Models from Stable Diffusion\n\n### install\n```\npip install diffusers\npip install accelerate\n```\n\n### demo\n```\ncd $"
  },
  {
    "path": "tokenizer/vae/reconstruction_vae_ddp.py",
    "chars": 8288,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.distribute"
  },
  {
    "path": "tokenizer/vae/sd_vae_demo.py",
    "chars": 2042,
    "preview": "import argparse\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\nfrom diffusers.mod"
  },
  {
    "path": "tokenizer/validation/val_ddp.py",
    "chars": 6025,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.distribute"
  },
  {
    "path": "tokenizer/vqgan/README.md",
    "chars": 513,
    "preview": "## Pretrained VQVAE Models\n\n### install\n```\npip install omegaconf\npip install einops\n```\n* download all needed models fr"
  },
  {
    "path": "tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml",
    "chars": 645,
    "preview": "model:\n  base_learning_rate: 4.5e-06\n  target: taming.models.vqgan.VQModel\n  params:\n    embed_dim: 256\n    n_embed: 102"
  },
  {
    "path": "tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml",
    "chars": 692,
    "preview": "model:\n  base_learning_rate: 4.5e-06\n  target: taming.models.vqgan.VQModel\n  params:\n    embed_dim: 256\n    n_embed: 163"
  },
  {
    "path": "tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml",
    "chars": 314,
    "preview": "model:\n  params:\n    embed_dim: 4\n    n_embed: 16384\n    ddconfig:\n      double_z: false\n      z_channels: 4\n      resol"
  },
  {
    "path": "tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml",
    "chars": 312,
    "preview": "model:\n  params:\n    embed_dim: 4\n    n_embed: 256\n    ddconfig:\n      double_z: false\n      z_channels: 4\n      resolut"
  },
  {
    "path": "tokenizer/vqgan/layer.py",
    "chars": 13787,
    "preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\n\ndef no"
  },
  {
    "path": "tokenizer/vqgan/model.py",
    "chars": 3456,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tokenizer.vqgan.layer import Encoder, Decoder\nf"
  },
  {
    "path": "tokenizer/vqgan/quantize.py",
    "chars": 9217,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom torch import einsum\nfrom eino"
  },
  {
    "path": "tokenizer/vqgan/reconstruction_vqgan_ddp.py",
    "chars": 8430,
    "preview": "import torch\ntorch.backends.cuda.matmul.allow_tf32 = True\ntorch.backends.cudnn.allow_tf32 = True\nimport torch.distribute"
  },
  {
    "path": "tokenizer/vqgan/taming_vqgan_demo.py",
    "chars": 2468,
    "preview": "import argparse\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\nfrom omegaconf imp"
  },
  {
    "path": "tools/check_image_codes.py",
    "chars": 2232,
    "preview": "import argparse\nimport torch\nimport numpy as np\n\nfrom tokenizer.tokenizer_image.vq_model import VQ_models\nfrom torchvisi"
  },
  {
    "path": "tools/convert_pytorch_lightning_to_torch.py",
    "chars": 686,
    "preview": "import os\nimport torch\n\nMODEL_PATH = 'pretrained_models'\npt_lightnings = [\n    'vqgan_imagenet_f16_1024/ckpts/last.ckpt'"
  },
  {
    "path": "tools/draw_figure.py",
    "chars": 4993,
    "preview": "import matplotlib.pyplot as plt\nimport numpy as np\n\nfont_size = 14\n\ndef fid_scaling_law_no_cfg():\n    # data\n    steps ="
  },
  {
    "path": "tools/imagenet_en_cn.py",
    "chars": 38508,
    "preview": "IMAGENET_1K_CLASSES = {\n  0: 'tench, Tinca tinca [丁鲷]',\n  1: 'goldfish, Carassius auratus [金鱼]',\n  2: 'great white shark"
  },
  {
    "path": "tools/openimage_json.py",
    "chars": 2488,
    "preview": "import argparse\nimport os\nimport json\nfrom PIL import Image\nimport multiprocessing as mp\n\nimport warnings\nwarnings.filte"
  },
  {
    "path": "tools/push_gpt_to_hf.py",
    "chars": 3342,
    "preview": "# Modified from:\n#   DiT:  https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py\nimport torch\ntorch.backends.c"
  },
  {
    "path": "tools/push_vae_to_hf.py",
    "chars": 1685,
    "preview": "\"\"\"\nScript to push and load custom PyTorch models to/from the Hugging Face Hub.\n\"\"\"\n\nimport argparse\nimport torch\nfrom t"
  },
  {
    "path": "utils/data.py",
    "chars": 827,
    "preview": "import numpy as np\nfrom PIL import Image\n\ndef center_crop_arr(pil_image, image_size):\n    \"\"\"\n    Center cropping implem"
  },
  {
    "path": "utils/deepspeed.py",
    "chars": 2753,
    "preview": "def create_deepspeed_config(args):\n    ds_config = {\n        \"steps_per_print\": 1000,\n        \"train_batch_size\": args.g"
  },
  {
    "path": "utils/distributed.py",
    "chars": 2093,
    "preview": "import os\nimport torch\nimport subprocess\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables prin"
  },
  {
    "path": "utils/drop_path.py",
    "chars": 1594,
    "preview": "# from timm.models.layers import DropPath\nimport torch\n\ndef drop_path(x, drop_prob: float = 0., training: bool = False, "
  },
  {
    "path": "utils/ema.py",
    "chars": 703,
    "preview": "import torch\nfrom collections import OrderedDict\n\n@torch.no_grad()\ndef update_ema(ema_model, model, decay=0.9999):\n    \""
  },
  {
    "path": "utils/logger.py",
    "chars": 665,
    "preview": "import logging\nimport torch.distributed as dist\n\ndef create_logger(logging_dir):\n    \"\"\"\n    Create a logger that writes"
  },
  {
    "path": "utils/video.py",
    "chars": 3478,
    "preview": "import math\nimport numpy as np\nimport skvideo.io\nfrom PIL import Image\n\n# Shifts src_tf dim to dest dim\n# i.e. shift_dim"
  }
]

// ... and 9 more files (download for full content)

About this extraction

This page contains an extraction of the source code of the hustvl/ControlAR GitHub repository, formatted as plain text for AI agents and large language models (LLMs); note that this view shows a preview only, and the full content is available via download. The extraction includes 184 files (2.6 MB), approximately 696.4k tokens, and a symbol index with 813 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — a free GitHub-repository-to-text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!