Repository: WZDTHU/NiT
Branch: main
Commit: 02fa292a7cd1
Files: 91
Total size: 565.6 KB
Directory structure:
gitextract_jijj82e1/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│ ├── c2i/
│ │ ├── nit_b_pack_merge_radio_65536.yaml
│ │ ├── nit_l_pack_merge_radio_16384.yaml
│ │ ├── nit_s_pack_merge_radio_65536.yaml
│ │ ├── nit_xl_pack_merge_radio_16384.yaml
│ │ └── nit_xxl_pack_merge_radio_8192.yaml
│ └── preprocess/
│ ├── imagenet1k_256x256.yaml
│ ├── imagenet1k_512x512.yaml
│ └── imagenet1k_native_resolution.yaml
├── nit/
│ ├── data/
│ │ ├── pack/
│ │ │ ├── __init__.py
│ │ │ ├── ennlshp.py
│ │ │ ├── lpfhp.py
│ │ │ ├── nnlshp.py
│ │ │ └── spfhp.py
│ │ ├── packed_c2i_data.py
│ │ └── sampler_util.py
│ ├── models/
│ │ ├── c2i/
│ │ │ └── nit_model.py
│ │ ├── nvidia_radio/
│ │ │ ├── hubconf.py
│ │ │ └── radio/
│ │ │ ├── __init__.py
│ │ │ ├── adaptor_base.py
│ │ │ ├── adaptor_generic.py
│ │ │ ├── adaptor_mlp.py
│ │ │ ├── adaptor_registry.py
│ │ │ ├── block.py
│ │ │ ├── cls_token.py
│ │ │ ├── common.py
│ │ │ ├── conv.py
│ │ │ ├── dinov2_arch.py
│ │ │ ├── dual_hybrid_vit.py
│ │ │ ├── enable_cpe_support.py
│ │ │ ├── enable_damp.py
│ │ │ ├── enable_spectral_reparam.py
│ │ │ ├── eradio_model.py
│ │ │ ├── extra_models.py
│ │ │ ├── extra_timm_models.py
│ │ │ ├── feature_normalizer.py
│ │ │ ├── forward_intermediates.py
│ │ │ ├── hf_model.py
│ │ │ ├── input_conditioner.py
│ │ │ ├── open_clip_adaptor.py
│ │ │ ├── radio_model.py
│ │ │ ├── vision_transformer_xpos.py
│ │ │ ├── vit_patch_generator.py
│ │ │ └── vitdet.py
│ │ └── utils/
│ │ ├── convs.py
│ │ ├── funcs.py
│ │ ├── norms.py
│ │ └── pos_embeds/
│ │ ├── flash_attn_rotary.py
│ │ ├── rope.py
│ │ └── sincos.py
│ ├── schedulers/
│ │ └── flow_matching/
│ │ ├── loss.py
│ │ └── samplers_c2i.py
│ └── utils/
│ ├── __init__.py
│ ├── deepspeed_zero_to_fp32.py
│ ├── ema.py
│ ├── eval_utils.py
│ ├── freeze.py
│ ├── gpu_memory_monitor.py
│ ├── lr_scheduler.py
│ ├── misc_utils.py
│ ├── model_utils.py
│ ├── train_utils.py
│ ├── util.py
│ ├── video_utils.py
│ └── warp_pos_idx.py
├── projects/
│ ├── evaluate/
│ │ └── adm_evaluator.py
│ ├── preprocess/
│ │ ├── image_latent_c2i.py
│ │ └── image_nr_latent_c2i.py
│ ├── sample/
│ │ └── sample_c2i_ddp.py
│ └── train/
│ └── packed_trainer_c2i.py
├── requirements.txt
├── scripts/
│ ├── preprocess/
│ │ ├── preorocess_in1k_256x256.sh
│ │ ├── preorocess_in1k_512x512.sh
│ │ └── preorocess_in1k_native_resolution.sh
│ ├── sample/
│ │ ├── sample_256x256.sh
│ │ ├── sample_512x512.sh
│ │ └── sample_768x768.sh
│ └── train/
│ ├── train_b_model.sh
│ ├── train_l_model.sh
│ ├── train_s_model.sh
│ ├── train_xl_model.sh
│ └── train_xxl_model.sh
├── setup.py
└── tools/
├── download_dataset_256x256.sh
├── download_dataset_512x512.sh
├── download_dataset_data_meta.sh
├── download_dataset_native_resolution.sh
├── download_dataset_sampler_meta.sh
└── pack_dataset.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
*.json
*.png
*.jpg
/checkpoints
/workdir
/datasets
/wandb
/samples
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
Native-Resolution Image Synthesis
Summary: We propose the Native-resolution diffusion Transformer (NiT), a model that explicitly learns varying resolutions and aspect ratios within its denoising process. This significantly improves training efficiency and generalization capability. To the best of our knowledge, NiT is the first to attain SOTA results on both the $256\times256$ ($2.08$ FID) and $512\times512$ ($1.48$ FID) benchmarks in class-guided ImageNet generation. NiT also generalizes to arbitrary resolutions and aspect ratios, e.g., $4.52$ FID at $1024\times1024$ resolution and $4.11$ FID at $432\times768$ resolution.

### 🚨 News
- `2025-9-18` NiT is accepted by NeurIPS 2025! 🍺
- `2025-6-3` We are delighted to introduce NiT, which is the first work to explicitly model native resolution image synthesis. We have released the code, pretrained models, and processed dataset of NiT.
### 1. Setup
First, clone the repo:
```bash
git clone https://github.com/WZDTHU/NiT.git && cd NiT
```
#### 1.1 Environment Setup
```bash
conda create -n nit_env python=3.10
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn
pip install -r requirements.txt
pip install -e .
```
#### 1.2 Model Zoo (WIP)
With a single model, NiT-XL can compete on multiple benchmarks and it achieves a dual SOTA on both ImageNet-$256\times256$ and $512\times512$ benchmarks.
| Model | Model Zoo | Model Size | FID-256x256 | FID-512x512 | FID-768x768 | FID-1024x1024 |
|---------------|------------|---------|------------|------------|------------|------------|
| NiT-XL-1000K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors) | 675M | 2.16 | 1.57 | 4.05 | 4.52 |
| NiT-XL-1500K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors) | 675M | 2.03 | 1.45 | - | - |
```bash
mkdir checkpoints
wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors" -O checkpoints/nit_xl_model_1000K.safetensors
wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors" -O checkpoints/nit_xl_model_1500K.safetensors
```
### 2. Sampling
#### 2.1 Sampling Hyper-parameters
The sampling hyper-parameters for NiT-XL-1000K are summarized as follows:
| Resolution | Solver | NFE | CFG - scale | CFG - interval | FID | sFID | IS | Prec. | Rec. |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 256 × 256 | SDE | 250 | 2.25 | [0.0, 0.7] | 2.16 | 6.34 | 253.44 | 0.79 | 0.62 |
| 512 × 512 | SDE | 250 | 2.05 | [0.0, 0.7] | 1.57 | 4.13 | 260.69 | 0.81 | 0.63 |
| 768 × 768 | ODE | 50 | 3.0 | [0.0, 0.7] | 4.05 | 8.77 | 262.31 | 0.83 | 0.52 |
| 1024 × 1024 | ODE | 50 | 3.0 | [0.0, 0.8] | 4.52 | 7.99 | 286.87 | 0.82 | 0.50 |
| 1536 × 1536 | ODE | 50 | 3.5 | [0.0, 0.9] | 6.51 | 9.97 | 230.10 | 0.83 | 0.42 |
| 2048 × 2048 | ODE | 50 | 4.5 | [0.0, 0.9] | 24.76 | 18.02 | 131.36 | 0.67 | 0.46 |
| 320 × 960 | ODE | 50 | 4.0 | [0.0, 0.9] | 16.85 | 17.79 | 189.18 | 0.71 | 0.38 |
| 432 × 768 | ODE | 50 | 2.75 | [0.0, 0.7] | 4.11 | 10.30 | 254.71 | 0.83 | 0.55 |
| 480 × 640 | ODE | 50 | 2.75 | [0.0, 0.7] | 3.72 | 8.23 | 284.94 | 0.83 | 0.54 |
| 640 × 480 | ODE | 50 | 2.5 | [0.0, 0.7] | 3.41 | 8.07 | 259.06 | 0.83 | 0.56 |
| 768 × 432 | ODE | 50 | 2.85 | [0.0, 0.7] | 5.27 | 9.92 | 218.78 | 0.80 | 0.55 |
| 960 × 320 | ODE | 50 | 4.5 | [0.0, 0.9] | 9.90 | 25.78 | 255.95 | 0.74 | 0.40 |
#### 2.2 Sampling Scripts
Sampling with NiT-XL-1000K model for $256\times256$-resolution images:
```bash
bash scripts/sample/sample_256x256.sh
```
Sampling with NiT-XL-1000K model for $512\times512$-resolution images:
```bash
bash scripts/sample/sample_512x512.sh
```
Sampling with NiT-XL-1000K model for $768\times768$-resolution images:
```bash
bash scripts/sample/sample_768x768.sh
```
### 3. Evaluation
The sampling generates a folder of samples to compute FID, Inception Score and
other metrics.
Note that we do not pack the generated samples into a `.npz` file; this does not affect the calculation of FID and other metrics.
Please follow the [ADM's TensorFlow
evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations)
to setup the conda-environment and download the reference batch.
```bash
wget -c "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" -O checkpoints/classify_image_graph_def.pb
```
Given the directory of the reference batch `REFERENCE_DIR` and the directory of the generated images `SAMPLING_DIR`, run the following codes:
```bash
python projects/evaluate/adm_evaluator.py $REFERENCE_DIR $SAMPLING_DIR
```
### 4. Training
#### 4.1 Dataset Setup
Currently, we provide all the [preprocessed dataset](https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K) for ImageNet1K. Please use the following commands to download the meta files and preprocessed latents.
```bash
mkdir datasets
mkdir datasets/imagenet1k
bash tools/download_dataset_256x256.sh
bash tools/download_dataset_512x512.sh
bash tools/download_dataset_native_resolution.sh
```
#### Preprocess ImageNet1K Locally
You can also preprocess the ImageNet1K dataset on your own.
Taking the $256\times256$ image preprocessing as an example, you should first set `data_dir` to your local ImageNet1K directory in `configs/preprocess/imagenet1k_256x256.yaml`.
Then run the preprocess script `scripts/preprocess/preorocess_in1k_256x256.sh`.
```bash
bash scripts/preprocess/preorocess_in1k_256x256.sh
```
The preprocessing procedure for $512\times512$ and native-resolution images is similar.
Modify the corresponding config file and run the script.
```bash
bash scripts/preprocess/preorocess_in1k_512x512.sh
bash scripts/preprocess/preorocess_in1k_native_resolution.sh
```
#### 4.2 Packing
As we pack multiple image instances with distinct resolutions into one sequence, we need to pre-set the image indices of each pack before the training process.
#### Download meta-info files
First, download all the data-meta files, which store the height, width, and other information of each image.
```bash
bash tools/download_dataset_data_meta.sh
```
The above command will download the four data-meta files into the `datasets/imagenet1k/data_meta` directory:
- `dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl`: data-meta file for $256\times256$-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl`, data-meta file for $512\times512$-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`, data-meta file for native-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl`, a merged file of the above three files.
The first two items of the native-resolution-image data-meta file (`dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`) are as follows:
```json
{"image_file": "n01601694/n01601694_11629.JPEG", "latent_file": "n01601694/n01601694_11629.safetensors", "ori_w": 580, "ori_h": 403, "latent_h": 12, "latent_w": 18, "image_h": 384, "image_w": 576, "type": "native-resolution"}
{"image_file": "n01601694/n01601694_11799.JPEG", "latent_file": "n01601694/n01601694_11799.safetensors", "ori_w": 500, "ori_h": 350, "latent_h": 10, "latent_w": 15, "image_h": 320, "image_w": 480, "type": "native-resolution"}
```
#### Sampler-Meta Download
Given the maximum length $L$, we pre-set the image indices of each pack before training.
Here we use the LPFHP (longest-pack-first histogram packing) algorithm to pack all the dataset.
You can download our preprocessed packed sampler-meta file using the following command.
```bash
bash tools/download_dataset_sampler_meta.sh
```
The above command will download the four sampler-meta files into the `datasets/imagenet1k/sampler_meta` directory:
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json`: corresponds to $L=8192$.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json`: corresponds to $L=16384$. This is the setting in NiT-XL experiments.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json`, corresponds to $L=32768$.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json`, corresponds to $L=65536$.
#### Prepare the Packing (Sampler-Meta) on Your Own
NiT supports training with images of arbitrary resolutions and aspect ratios, you can also prepare the packing (sampler-meta) according to your own demands.
```bash
# generate the default sampler-meta
python tools/pack_dataset.py
# generate the sampler-meta for the fixed 256x256-resolution experiment with a maximum sequence length of 16384
python tools/pack_dataset.py --data-meta datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl --max-seq-len 16384
```
#### Download Image Encoder
For the NiT-S (33M) model, we use RADIO-v2.5-B as the image encoder for the REPA loss.
For other NiT models, we use RADIO-v2.5-H as our image encoder.
```bash
wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar" -O checkpoints/radio_v2.5-h.pth.tar
wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar" -O checkpoints/radio-v2.5-b_half.pth.tar
```
#### Training Scripts
The above steps set up the `packed_json`, `jsonl_dir`, and `latent_dirs` in `configs/c2i/nit_xl_pack_merge_radio_16384.yaml`.
Before training, please specify the `image_dir` as the directory of ImageNet1K dataset in your own machine.
To train the XL-model (675M):
```bash
bash scripts/train/train_xl_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_s_pack_merge_radio_65536.yaml` and train the small model (33M):
```bash
bash scripts/train/train_s_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_b_pack_merge_radio_65536.yaml` and train the base-model (131M):
```bash
bash scripts/train/train_b_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_l_pack_merge_radio_16384.yaml` and train the large model (457M):
```bash
bash scripts/train/train_l_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_xxl_pack_merge_radio_8192.yaml` and train the xxl-model (1.37B):
```bash
bash scripts/train/train_xxl_model.sh
```
### Citations
If you find the project useful, please kindly cite:
```bibtex
@article{wang2025native,
title={Native-Resolution Image Synthesis},
author={Wang, Zidong and Bai, Lei and Yue, Xiangyu and Ouyang, Wanli and Zhang, Yiyuan},
year={2025},
eprint={2506.03131},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
### License
This project is licensed under the Apache-2.0 license.
================================================
FILE: configs/c2i/nit_b_pack_merge_radio_65536.yaml
================================================
model:
transport:
path_type: linear
prediction: v
weighting: lognormal
network:
target: nit.models.c2i.nit_model.NiT
params:
class_dropout_prob: 0.1
num_classes: 1000
depth: 12
hidden_size: 768
patch_size: 1
in_channels: 32
num_heads: 12
qk_norm: True
encoder_depth: 4
z_dim: 1280
use_checkpoint: False
# pretrained_vae:
vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
slice_vae: False
tile_vae: False
# repa encoder
enc_type: radio
enc_dir: checkpoints/radio_v2.5-h.pth.tar
proj_coeff: 1.0
# ema
use_ema: True
ema_decay: 0.9999
data:
data_type: improved_pack
dataset:
packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json
jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
latent_dirs: [
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
]
image_dir: /train
dataloader:
num_workers: 4
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 'c2i'}}
max_train_steps: 2000000
checkpointing_steps: 2000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 5.0e-5
learning_rate_base_batch_size: 1
scale_lr: True
lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 0
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [200000, 500000, 100000, 150000]
================================================
FILE: configs/c2i/nit_l_pack_merge_radio_16384.yaml
================================================
model:
transport:
path_type: linear
prediction: v
weighting: lognormal
network:
target: nit.models.c2i.nit_model.NiT
params:
class_dropout_prob: 0.1
num_classes: 1000
depth: 24
hidden_size: 1024
patch_size: 1
in_channels: 32
num_heads: 16
qk_norm: True
encoder_depth: 6
z_dim: 1280
use_checkpoint: False
# pretrained_vae:
vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
slice_vae: False
tile_vae: False
# repa encoder
enc_type: radio
enc_dir: checkpoints/radio_v2.5-h.pth.tar
proj_coeff: 1.0
# ema
use_ema: True
ema_decay: 0.9999
data:
data_type: improved_pack
dataset:
packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json
jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
latent_dirs: [
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
]
image_dir: /train
dataloader:
num_workers: 4
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 'c2i'}}
max_train_steps: 2000000
checkpointing_steps: 2000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 5.0e-5
learning_rate_base_batch_size: 4
scale_lr: True
lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 0
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [200000, 500000, 100000, 150000]
================================================
FILE: configs/c2i/nit_s_pack_merge_radio_65536.yaml
================================================
model:
transport:
path_type: linear
prediction: v
weighting: lognormal
network:
target: nit.models.c2i.nit_model.NiT
params:
class_dropout_prob: 0.1
num_classes: 1000
depth: 12
hidden_size: 384
patch_size: 1
in_channels: 32
num_heads: 6
qk_norm: True
encoder_depth: 4
z_dim: 768
projector_dim: 768
use_checkpoint: False
# pretrained_vae:
vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
slice_vae: False
tile_vae: False
# repa encoder
enc_type: radio
enc_dir: checkpoints/radio-v2.5-b_half.pth.tar
proj_coeff: 1.0
# ema
use_ema: True
ema_decay: 0.9999
data:
data_type: improved_pack
dataset:
packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json
jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
latent_dirs: [
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
]
image_dir: /train
dataloader:
num_workers: 4
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 'c2i'}}
max_train_steps: 2000000
checkpointing_steps: 2000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 5.0e-5
learning_rate_base_batch_size: 1
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 0
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [200000, 500000, 100000, 150000]
================================================
FILE: configs/c2i/nit_xl_pack_merge_radio_16384.yaml
================================================
model:
transport:
path_type: linear
prediction: v
weighting: lognormal
network:
target: nit.models.c2i.nit_model.NiT
params:
class_dropout_prob: 0.1
num_classes: 1000
depth: 28
hidden_size: 1152
patch_size: 1
in_channels: 32
num_heads: 16
qk_norm: True
encoder_depth: 8
z_dim: 1280
use_checkpoint: False
# pretrained_vae:
vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
slice_vae: False
tile_vae: False
# repa encoder
enc_type: radio
enc_dir: checkpoints/radio_v2.5-h.pth.tar
proj_coeff: 1.0
# ema
use_ema: True
ema_decay: 0.9999
data:
data_type: improved_pack
dataset:
packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json
jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
latent_dirs: [
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
]
image_dir: /train
dataloader:
num_workers: 4
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 'c2i'}}
max_train_steps: 2000000
checkpointing_steps: 2000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 5.0e-5
learning_rate_base_batch_size: 4
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 0
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [200000, 500000, 100000, 150000]
================================================
FILE: configs/c2i/nit_xxl_pack_merge_radio_8192.yaml
================================================
model:
transport:
path_type: linear
prediction: v
weighting: lognormal
network:
target: nit.models.c2i.nit_model.NiT
params:
class_dropout_prob: 0.1
num_classes: 1000
depth: 40
hidden_size: 1536
patch_size: 1
in_channels: 32
num_heads: 24
qk_norm: True
encoder_depth: 8
z_dim: 1280
use_checkpoint: False
use_adaln_lora: True
adaln_lora_dim: 512
# pretrained_vae:
vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
slice_vae: False
tile_vae: False
# repa encoder
enc_type: radio
enc_dir: checkpoints/radio_v2.5-h.pth.tar
proj_coeff: 1.0
# ema
use_ema: True
ema_decay: 0.9999
data:
data_type: improved_pack
dataset:
packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json
jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
latent_dirs: [
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
]
image_dir: /train
dataloader:
num_workers: 4
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 'c2i'}}
max_train_steps: 1000000
checkpointing_steps: 2000
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 5.0e-5
learning_rate_base_batch_size: 4
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 0
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [200000, 500000, 100000]
================================================
FILE: configs/preprocess/imagenet1k_256x256.yaml
================================================
model:
vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
data:
dataset:
data_dir: /train
target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256
resolution: 256
dataloader:
num_workers: 1
batch_size: 64 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 't2i'}}
max_train_steps: 100000
checkpointing_steps: 200
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 1.0e-4
learning_rate_base_batch_size: 256
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 4000
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [20000, 40000, 60000, 80000]
================================================
FILE: configs/preprocess/imagenet1k_512x512.yaml
================================================
model:
vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
data:
dataset:
data_dir: /train
target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
resolution: 512
dataloader:
num_workers: 1
batch_size: 16 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 't2i'}}
max_train_steps: 100000
checkpointing_steps: 200
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 1.0e-4
learning_rate_base_batch_size: 256
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 4000
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [20000, 40000, 60000, 80000]
================================================
FILE: configs/preprocess/imagenet1k_native_resolution.yaml
================================================
model:
vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
data:
dataset:
data_dir: /train
target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution
min_image_size: 32
max_image_size: 2048
dataloader:
num_workers: 1
batch_size: 1 # Batch size (per device) for the training dataloader.
training:
tracker: null
tracker_kwargs: {'wandb': {'group': 't2i'}}
max_train_steps: 100000
checkpointing_steps: 200
checkpoints_total_limit: 2
resume_from_checkpoint: latest
learning_rate: 1.0e-4
learning_rate_base_batch_size: 256
scale_lr: True
lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
lr_warmup_steps: 4000
gradient_accumulation_steps: 1
optimizer:
target: torch.optim.AdamW
params:
# betas: ${tuple:0.9, 0.999}
betas: [0.9, 0.95]
weight_decay: 1.0e-2
eps: 1.0e-6
max_grad_norm: 1.0
proportion_empty_prompts: 0.0
mixed_precision: bf16 # ["no", "fp16", "bf16"]
allow_tf32: True
validation_steps: 500
checkpoint_list: [20000, 40000, 60000, 80000]
================================================
FILE: nit/data/pack/__init__.py
================================================
from .ennlshp import ENNLSHP
from .lpfhp import LPFHP
from .nnlshp import NNLSHP
from .spfhp import SPFHP
import json
import torch
import numpy as np
from tqdm import tqdm
def get_strategy(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens):
    """Build a histogram-packing strategy for a dataset of sequence lengths.

    Args:
        algorithm: one of "SPFHP", "LPFHP", "ENNLSHP", "NNLSHP".
        max_seq_len: maximum packed sequence length; every entry of
            dataset_seq_lens is assumed to lie in [1, max_seq_len]
            (longer sequences would index past the histogram).
        max_seq_per_pack: maximum number of sequences merged into one pack.
        dataset_seq_lens: iterable of per-sample sequence lengths.

    Returns:
        (strategy_set, strategy_repeat_count) as produced by the chosen packer.

    Raises:
        NotImplementedError: if `algorithm` is not one of the supported names.
    """
    def generate_histogram(dataset_seq_lens):
        # histogram[l - 1] = number of samples whose length is exactly l
        histogram = np.zeros(max_seq_len, dtype=np.int64)
        seq_lens, counts = np.unique(np.array(dataset_seq_lens), return_counts=True)
        histogram[seq_lens - 1] = counts
        return histogram
    histogram = generate_histogram(dataset_seq_lens)
    # Bug fix: the previous error message only mentioned LPFHP and SPFHP even
    # though ENNLSHP and NNLSHP are implemented and dispatched below.
    if algorithm not in ("SPFHP", "LPFHP", "ENNLSHP", "NNLSHP"):
        raise NotImplementedError(
            "Algorithm type unsupported. Pass one of: SPFHP, LPFHP, ENNLSHP, NNLSHP"
        )
    packers = {"SPFHP": SPFHP, "LPFHP": LPFHP, "ENNLSHP": ENNLSHP, "NNLSHP": NNLSHP}
    return packers[algorithm](histogram, max_seq_len, max_seq_per_pack)
def pack_dataset(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens, dataset_seq_idxs):
    """Group dataset sample indices into packs following a histogram strategy.

    Returns a list of packs; each pack is a list of dataset indices whose
    sequence lengths jointly fit into one max_seq_len-token sequence.
    """
    # Row 0 holds the still-unassigned sequence lengths (consumed slots become
    # -1); row 1 holds the dataset index paired with each length.
    seqs = torch.stack([torch.tensor(dataset_seq_lens), torch.tensor(dataset_seq_idxs)])
    strategy_set, strategy_repeat_count = get_strategy(
        algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens
    )
    packed_indices = []
    progress_bar = tqdm(range(sum(strategy_repeat_count)))
    for strategy, repeats in zip(strategy_set, strategy_repeat_count):
        for _ in range(repeats):
            progress_bar.update(1)
            picked = []
            for seq_len in strategy:
                # Consume the last not-yet-assigned sample of this length.
                pos = torch.argwhere(seqs[0] == seq_len)[-1]
                seqs[0, pos] = -1
                picked.append(pos)
            packed_indices.append(seqs[1, picked].ravel().tolist())
    return packed_indices
================================================
FILE: nit/data/pack/ennlshp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/ennlshp.py
"""Extended Non-Negative least squares histogram-packing."""
import time
import numpy as np
from scipy import optimize, stats
from functools import lru_cache
def get_packing_matrix(strategy_set, max_sequence_length):
    """Build the histogram-to-strategy incidence matrix.

    Entry [l - 1, j] counts how many sequences of length l strategy j consumes.
    """
    matrix = np.zeros((max_sequence_length, len(strategy_set)), dtype=np.int32)
    for col, lengths in enumerate(strategy_set):
        for length in lengths:
            matrix[length - 1, col] += 1
    return matrix
@lru_cache(maxsize=None)
def get_packing_strategies(start_length, minimum_increment, target_length, depth):
    """Enumerate all non-decreasing length tuples completing a pack.

    Each returned strategy has at most `depth` entries, each entry is at least
    `minimum_increment`, and the entries plus `start_length` sum to exactly
    `target_length`. Memoized: callers must not mutate the returned lists.
    """
    remaining = target_length - start_length
    # A single slot left: only the exact remainder can complete the pack.
    if depth == 1:
        return [[remaining]] if remaining >= minimum_increment else []
    found = []
    for first in range(minimum_increment, remaining + 1):
        if first == remaining:
            found.append([first])
            continue
        for tail in get_packing_strategies(start_length + first, first, target_length, depth - 1):
            if tail:
                found.append([first] + tail)
    return found
def ENNLSHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Extended non-negative least-squares histogram-packing.

    Solves a weighted NNLS system whose weights penalize padding tokens and
    strongly discourage a positive residual, rounds the solution, then repairs
    it so that every leftover sequence still gets placed.

    Args:
        histogram: histogram[l - 1] = number of sequences of length l.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum number of sequences per pack.

    Returns:
        (strategy_set, strategy_repeat_count): the candidate packs and how
        often each one is used.
    """
    # List all unique ways of packing to the desired maximum sequence length
    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)
    # Get the packing matrix corresponding to this list of packing strategies
    A = get_packing_matrix(strategy_set, max_sequence_length)
    # Weights that penalize the residual by the number of resulting padding tokens.
    w0 = np.array([x + 1 for x in range(max_sequence_length)])
    # construct the packing matrix
    A_bar = np.zeros((2 * max_sequence_length, len(strategy_set) + max_sequence_length), "d")
    # Base weighted matrix
    A_bar[:max_sequence_length, : len(strategy_set)] = np.expand_dims(w0, -1) * A
    # Higher weight to avoid positive residual
    A_bar[max_sequence_length:, : len(strategy_set)] = np.expand_dims(10**6 * np.ones([max_sequence_length]), -1) * A
    # negative diagonal unity matrix for mapping to residual
    A_bar[max_sequence_length:, len(strategy_set) :] = np.expand_dims(
        10**6 * np.ones([max_sequence_length]), -1
    ) * np.ones((max_sequence_length, max_sequence_length))
    b_bar = np.zeros(2 * max_sequence_length)
    # Apply weighting to histogram vector
    b_bar[:max_sequence_length] = w0 * histogram
    b_bar[max_sequence_length:] = 10**6 * np.ones([max_sequence_length]) * histogram
    # Solve the packing problem
    start = time.time()
    strategy_residual, rnorm = optimize.nnls(A_bar, b_bar)
    strategy_repeat_count = strategy_residual[: len(strategy_set)]
    # Round the floating point solution to nearest integer
    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)
    # Compute the residuals, shape: [max_sequence_length]
    residual = histogram - A @ strategy_repeat_count
    # Handle the left-over sequences; that is the positive part of residual
    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]
    for l in unpacked_seqlen:
        strategy = sorted([l, max_sequence_length - l]) # the depth 1 strategy
        strategy_index = strategy_set.index(strategy)
        strategy_repeat_count[strategy_index] += residual[l - 1]
    # Re-compute the residual with the updated strategy_repeat_count
    # This should now be strictly < 0
    residual = histogram - A @ strategy_repeat_count
    # Add padding based on deficit (negative residual portion of residual)
    padding = np.where(residual < 0, -residual, 0)
    # Calculate some basic statistics
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = int(strategy_repeat_count.sum())
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    num_padding_tokens_packed = (sequence_lengths * padding).sum()
    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, strategy_repeat_count
================================================
FILE: nit/data/pack/lpfhp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/lpfhp.py
"""Longest-pack-first histogram-packing."""
from collections import defaultdict
import numpy as np
import time
def add_pack(pack, count, tmp, final, limit, offset, max_sequence_length=512):
    """Route a (count, pack) entry into *final* when complete, else *tmp*.

    A pack is complete when it already holds *limit* sequences or has no slack
    (offset == 0). *offset* must equal the pack's remaining token slack.
    """
    # sanity checks: offset must be the pack's remaining slack, within bounds
    assert max_sequence_length - sum(pack) == offset, "Incorrect offset."
    assert offset >= 0, "Too small offset."
    assert offset < max_sequence_length, "Too large offset."
    bucket = final if (offset == 0 or len(pack) == limit) else tmp
    bucket[offset].append((count, pack))
def LPFHP(histogram, max_sequence_length, max_sequences_per_pack, distribute=True):
    """Longest-pack-first histogram-packing.

    Greedily bins sequences from longest to shortest, preferring to top up the
    fullest existing partial pack that still has room.

    Args:
        histogram: histogram[l - 1] = number of sequences of length l.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum sequences per pack, or the string
            "max" for no practical limit.
        distribute: when True, one step may place several copies of the
            current length into the same pack; when False, only one copy.

    Returns:
        (strategy_set, strategy_repeat_count): list of packs (each a list of
        sequence lengths) and how often each pack is used.
    """
    start = time.time()
    reversed_histogram = np.flip(histogram)
    # Initialize main strategy data dictionary.
    # The key indicates how many tokens are left for full length.
    # The value is a list of tuples, consisting of counts and respective packs.
    # A pack is a (sorted) list of sequence length values that get concatenated.
    tmp_strategies_per_length = defaultdict(list)
    strategies_per_length = defaultdict(list)
    if max_sequences_per_pack == "max":
        max_sequences_per_pack = max_sequence_length
    # Index i indicates here, how much space is left, due to reversed histogram
    for i in range(max_sequence_length):
        n_sequences_to_bin = reversed_histogram[i]
        length_to_bin = max_sequence_length - i
        offset = 0 # smallest possible offset for perfect fit
        while n_sequences_to_bin > 0:
            if (length_to_bin + offset) in tmp_strategies_per_length:
                # extract worst pack that will get modified
                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()
                # calculate how often the current sequence maximally fits in
                repeat = min(1 + offset // length_to_bin, max_sequences_per_pack - len(pack))
                # correct dependent on count
                while n_sequences_to_bin // repeat == 0:
                    repeat -= 1
                if not distribute:
                    repeat = 1
                new_pack = pack + [length_to_bin] * repeat
                count = min(n_sequences_to_pack, n_sequences_to_bin // repeat)
                if n_sequences_to_pack > count:
                    # old pack gets reduced
                    n_sequences_to_pack -= count
                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))
                    n_sequences_to_bin -= count * repeat
                else:
                    n_sequences_to_bin -= n_sequences_to_pack * repeat
                add_pack(
                    new_pack,
                    count,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    offset - (repeat - 1) * length_to_bin,
                    max_sequence_length,
                )
                # clean up to speed up main key search
                if not tmp_strategies_per_length[length_to_bin + offset]:
                    tmp_strategies_per_length.pop(length_to_bin + offset)
                # reset offset in case best fit changed
                offset = 0
            else:
                offset += 1
            # Does not fit anywhere. Create new pack.
            if offset >= max_sequence_length - length_to_bin + 1:
                # similar repetition but no dependence on pack.
                repeat = min(max_sequence_length // length_to_bin, max_sequences_per_pack)
                while n_sequences_to_bin // repeat == 0:
                    repeat -= 1
                if not distribute:
                    repeat = 1
                add_pack(
                    [length_to_bin] * repeat,
                    n_sequences_to_bin // repeat,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    max_sequence_length - length_to_bin * repeat,
                    max_sequence_length,
                )
                n_sequences_to_bin -= n_sequences_to_bin // repeat * repeat
    # merge all strategies
    for key in tmp_strategies_per_length:
        strategies_per_length[key].extend(tmp_strategies_per_length[key])
    # flatten strategies dictionary
    strategy_set = []
    strategy_repeat_count = []
    for key in strategies_per_length:
        for count, pack in strategies_per_length[key]:
            pack.reverse()
            strategy_set.append(pack)
            strategy_repeat_count.append(count)
    # Summarize efficiency of solution
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    strategy_repeat_count = np.array(strategy_repeat_count)
    n_strategies = len(strategy_set)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = strategy_repeat_count.sum()
    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])
    total_tokens = max_sequence_length * new_number_of_samples
    empty_tokens = sum(
        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]
    )
    efficiency = 100 - empty_tokens / total_tokens * 100
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}",
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, strategy_repeat_count
================================================
FILE: nit/data/pack/nnlshp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/nnlshp.py
"""Non-Negative least squares histogram-packing."""
import time
import numpy as np
from scipy import optimize, stats
from functools import lru_cache
def get_packing_matrix(strategy_set, max_sequence_length):
    """Return matrix A where A[l - 1, j] is the multiplicity of length l in strategy j."""
    packing_matrix = np.zeros((max_sequence_length, len(strategy_set)), dtype=np.int32)
    for column, strategy in enumerate(strategy_set):
        for seq_len in strategy:
            packing_matrix[seq_len - 1, column] += 1
    return packing_matrix
@lru_cache(maxsize=None)
def get_packing_strategies(start_length, minimum_increment, target_length, depth):
    """Enumerate non-decreasing length combinations that exactly fill a pack.

    Each strategy has at most `depth` entries, every entry is at least
    `minimum_increment`, and the entries plus `start_length` sum to
    `target_length`. Memoized: callers must not mutate the returned lists.
    """
    gap = target_length - start_length
    # Base case: one slot left, so only the exact remainder fits.
    if depth == 1:
        return [[gap]] if gap >= minimum_increment else []
    strategies = []
    for head in range(minimum_increment, gap + 1):
        if head == gap:
            strategies.append([head])
            continue
        for rest in get_packing_strategies(start_length + head, head, target_length, depth - 1):
            if rest:
                strategies.append([head] + rest)
    return strategies
def NNLSHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Non-negative least-squares histogram-packing.

    Solves a weighted NNLS system that matches the sequence-length histogram
    with a non-negative combination of packing strategies, rounds the
    solution, then repairs it so every leftover sequence still gets placed.

    Args:
        histogram: histogram[l - 1] = number of sequences of length l.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum number of sequences per pack.

    Returns:
        (strategy_set, strategy_repeat_count): candidate packs and how often
        each pack is used.
    """
    # List all unique ways of packing to the desired maximum sequence length
    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)
    # Get the packing matrix corresponding to this list of packing strategies
    A = get_packing_matrix(strategy_set, max_sequence_length)
    # Weights that penalize the residual on short sequences less.
    penalization_cutoff = 8
    w0 = np.ones([max_sequence_length])
    w0[:penalization_cutoff] = 0.09
    # Solve the packing problem
    start = time.time()
    strategy_repeat_count, rnorm = optimize.nnls(np.expand_dims(w0, -1) * A, w0 * histogram)
    # Round the floating point solution to nearest integer
    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)
    # Compute the residuals, shape: [max_sequence_length]
    residual = histogram - A @ strategy_repeat_count
    # Handle the left-over sequences, that is the positive part of residual
    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]
    for l in unpacked_seqlen:
        strategy = sorted([l, max_sequence_length - l]) # the depth 1 strategy
        strategy_index = strategy_set.index(strategy)
        strategy_repeat_count[strategy_index] += residual[l - 1]
    # Re-compute the residual with the updated strategy_repeat_count
    # This should now be strictly < 0
    residual = histogram - A @ strategy_repeat_count
    # Add padding based on deficit (negative residual portion of residual)
    padding = np.where(residual < 0, -residual, 0)
    # Calculate some basic statistics
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = int(strategy_repeat_count.sum())
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    num_padding_tokens_packed = (sequence_lengths * padding).sum()
    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, strategy_repeat_count
================================================
FILE: nit/data/pack/spfhp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/spfhp.py
"""Shortest-pack-first histogram-packing."""
from collections import defaultdict
import numpy as np
import time
def add_pack(pack, count, tmp, final, limit, offset):
    """Route a (count, pack) entry: finished packs go to *final*, open ones to *tmp*.

    A pack is finished when it holds *limit* sequences or has no slack left
    (offset == 0).
    """
    is_complete = offset == 0 or len(pack) == limit
    destination = final if is_complete else tmp
    destination[offset].append((count, pack))
def SPFHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Shortest-pack-first histogram-packing.

    Greedily bins sequences from longest to shortest, preferring to place each
    sequence into the emptiest existing pack it fits into.

    Args:
        histogram: histogram[l - 1] = number of sequences of length l.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum number of sequences per pack.

    Returns:
        (strategy_set, strategy_repeat_count): list of packs (each a list of
        sequence lengths) and how often each pack is used.
    """
    start = time.time()
    reversed_histogram = np.flip(histogram)
    # Initialize main strategy data dictionary.
    # The key indicates how many tokens are left for full length.
    # The value is a list of tuples, consisting of counts and respective packs.
    # A pack is a (sorted) list of sequence length values that get concatenated.
    tmp_strategies_per_length = defaultdict(list)
    strategies_per_length = defaultdict(list)
    # Index i indicates here, how much space is left, due to reversed histogram
    for i in range(max_sequence_length):
        n_sequences_to_bin = reversed_histogram[i]
        length_to_bin = max_sequence_length - i
        offset = i + 1 # largest possible offset
        while n_sequences_to_bin > 0:
            if (length_to_bin + offset) in tmp_strategies_per_length:
                # extract shortest pack that will get modified
                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()
                new_pack = pack + [length_to_bin]
                count = min(n_sequences_to_pack, n_sequences_to_bin)
                if n_sequences_to_pack > n_sequences_to_bin:
                    # old pack gets reduced
                    n_sequences_to_pack -= n_sequences_to_bin
                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))
                    n_sequences_to_bin = 0
                else:
                    n_sequences_to_bin -= n_sequences_to_pack
                add_pack(
                    new_pack,
                    count,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    offset,
                )
                # clean up to speed up main key search
                if not tmp_strategies_per_length[length_to_bin + offset]:
                    tmp_strategies_per_length.pop(length_to_bin + offset)
            else:
                offset -= 1
            # Does not fit anywhere. Create new pack.
            if offset < 0:
                add_pack(
                    [length_to_bin],
                    n_sequences_to_bin,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    i,
                )
                n_sequences_to_bin = 0
    # merge all strategies
    for key in tmp_strategies_per_length:
        strategies_per_length[key].extend(tmp_strategies_per_length[key])
    # flatten strategies dictionary
    strategy_set = []
    strategy_repeat_count = []
    for key in strategies_per_length:
        for count, pack in strategies_per_length[key]:
            pack.reverse()
            strategy_set.append(pack)
            strategy_repeat_count.append(count)
    # Summarize efficiency of solution
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    strategy_repeat_count = np.array(strategy_repeat_count)
    n_strategies = len(strategy_set)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = strategy_repeat_count.sum()
    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])
    total_tokens = max_sequence_length * new_number_of_samples
    empty_tokens = sum(
        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]
    )
    efficiency = 100 - empty_tokens / total_tokens * 100
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n",
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, np.array(strategy_repeat_count)
================================================
FILE: nit/data/packed_c2i_data.py
================================================
import os
import datetime
import torchvision
import numpy as np
import torch
import ast
import json
import time
from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import load_file
from einops import rearrange
from functools import partial
from torchvision.transforms.functional import hflip
from .sampler_util import get_train_sampler, get_packed_batch_sampler
logger = get_logger(__name__, log_level="INFO")
PATCH_SIZE = 1
def resize_arr(pil_image, height, width):
    """Resize a PIL image to exactly (width, height) using bicubic resampling."""
    return pil_image.resize((width, height), resample=Image.Resampling.BICUBIC)
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Progressively box-downsample while the short side is still >= 2x the
    # target, mirroring ADM's anti-aliasing before the final bicubic resize.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.Resampling.BOX)
    # Bicubic resize so the short side matches image_size.
    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.Resampling.BICUBIC)
    # Crop the central image_size x image_size window.
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top: top + image_size, left: left + image_size])
def packed_collate_fn(batch):
    """Collate variable-resolution samples into one packed token sequence.

    Each latent (C, H, W) is split into PATCH_SIZE x PATCH_SIZE patches that
    are flattened along the token axis; all samples are concatenated into a
    single sequence. hw_list records each sample's token grid (h, w) so the
    packed sequence can be split apart again downstream.
    """
    tokens, labels, grids, images = [], [], [], []
    for sample in batch:
        _, height, width = sample['latent'].shape
        patches = rearrange(
            sample['latent'], 'c (h p1) (w p2) -> (h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE
        )
        tokens.append(patches)
        grids.append([height / PATCH_SIZE, width / PATCH_SIZE])
        labels.append(sample['label'])
        images.append(sample['image'])
    return dict(
        image=images,
        latent=torch.concat(tokens),
        label=torch.tensor(labels),
        hw_list=torch.tensor(grids, dtype=torch.int32),
    )
class ImprovedPackedImageNetLatentDataset(Dataset):
    """ImageNet dataset serving precomputed VAE latents plus matching pixels.

    Per-sample metadata comes from a JSONL file; the packing plan loaded from
    `packed_json` is stored on the instance for the batch sampler to consume
    (it is not used by __getitem__ itself).
    """
    def __init__(self, packed_json, jsonl_dir, data_types, latent_dirs, image_dir):
        """
        Args:
            packed_json: path to the JSON packing plan (groups of sample indices).
            jsonl_dir: path to the JSONL metadata file, one record per sample.
            data_types: data-type names (e.g. 'native-resolution', 'fixed-256x256').
            latent_dirs: latent root directories, parallel to data_types.
            image_dir: root directory of the raw images.
        """
        super().__init__()
        assert len(data_types) == len(latent_dirs)
        # Map each data type to the directory holding its latents.
        self.type_to_dir = dict()
        for i, data_type in enumerate(data_types):
            self.type_to_dir[data_type] = latent_dirs[i]
        self.image_dir = image_dir
        # Packing plan consumed by get_packed_batch_sampler (see C2ILoader).
        with open(packed_json, 'r') as fp:
            self.packed_dataset = json.load(fp)
        # One metadata dict per line: type, latent_file, image_file, latent_h/w, ...
        with open(jsonl_dir, 'r') as fp:
            self.dataset = [json.loads(line) for line in fp]
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        """Return dict with 'image' (pixels), 'latent' and 'label' for one sample."""
        data_meta = self.dataset[index]
        data_item = dict()
        data_type = data_meta['type']
        latent_file = os.path.join(self.type_to_dir[data_type], data_meta['latent_file'])
        image_file = os.path.join(self.image_dir, data_meta['image_file'])
        data = load_file(latent_file)
        # Pixel size is latent size x 16 — presumably the effective spatial
        # factor of the stored latent layout; TODO confirm against preprocessing.
        height = data_meta['latent_h'] * 16
        width = data_meta['latent_w'] * 16
        if data_type == 'native-resolution':
            preprocess = partial(resize_arr, height=height, width=width)
        else:
            assert height == width
            preprocess = partial(center_crop_arr, image_size=height)
        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: preprocess(pil_image=pil_image)),
            transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])
        # rand_idx picks original (0) or hflipped (1); the same index is used
        # for data['latent'], which presumably stores (original, flipped)
        # latents in that order so image and latent stay aligned — verify.
        rand_idx = torch.randint(low=0, high=2, size=(1,)).item()
        data_item['image'] = transform(Image.open(image_file).convert("RGB"))[rand_idx]
        data_item['latent'] = data['latent'][rand_idx]
        data_item['label'] = data['label']
        return data_item
class C2ILoader():
    """
    Dataloader factory for class-to-image (c2i) training.

    Builds the training dataset from `data_config` and exposes train/val/test
    dataloader constructors. Only the 'improved_pack' data type is currently
    supported; it uses pre-packed variable-length batches via a batch sampler.
    """
    def __init__(self, data_config):
        super().__init__()
        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        # Fix: val_dataloader() read `self.shuffle`, which was never defined and
        # raised AttributeError. Default to False when absent from the config.
        self.shuffle = getattr(data_config.dataloader, 'shuffle', False)
        self.data_type = data_config.data_type
        if data_config.data_type == 'improved_pack':
            self.train_dataset = ImprovedPackedImageNetLatentDataset(
                **OmegaConf.to_container(data_config.dataset)
            )
        else:
            raise NotImplementedError
        # No separate eval splits are configured for this loader.
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        """Number of samples (JSONL records) in the training dataset."""
        return len(self.train_dataset)

    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
        """
        Build the (resumable, deterministic) per-rank training dataloader.

        For 'improved_pack', pre-packed batches are fed through a batch sampler
        and collated with `packed_collate_fn`; otherwise a flat per-sample
        sampler is used.
        """
        if self.data_type == 'improved_pack':
            batch_sampler = get_packed_batch_sampler(
                self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed
            )
            return DataLoader(
                self.train_dataset,
                batch_sampler=batch_sampler,
                collate_fn=packed_collate_fn,
                num_workers=self.num_workers,
                pin_memory=True,
            )
        # Fix: the per-sample sampler was previously built unconditionally, even
        # on the packed path where it was discarded — build it only when needed.
        sampler = get_train_sampler(
            self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )

    def test_dataloader(self):
        """No test split is configured; always returns None."""
        return None

    def val_dataloader(self):
        """Validation loader; iterates the training dataset (no val split exists)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True
        )
================================================
FILE: nit/data/sampler_util.py
================================================
import torch
import json
# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps,
                      resume_step, seed):
    """
    Build a deterministic, resumable per-rank list of sample indices.

    Each "epoch" draws a fresh seeded permutation of the dataset and takes this
    rank's strided shard; the rotating offset compensates for dataset sizes that
    are not divisible by world_size, so leftover samples are not always dropped
    by the same ranks. The leading `resume_step` steps' worth of indices are
    skipped so training can resume mid-schedule.
    """
    total_local = max_steps * global_batch_size // world_size
    indices = torch.empty([total_local], dtype=torch.long)
    epoch, filled, offset = 0, 0, 0
    while filled < indices.size(0):
        gen = torch.Generator()
        gen.manual_seed(seed + epoch)  # same permutation on every rank
        perm = torch.randperm(len(dataset), generator=gen)
        epoch += 1
        # This rank's strided shard, rotated by the carry-over offset.
        shard = perm[(rank + offset) % world_size::world_size]
        offset = (offset + world_size - len(dataset) % world_size) % world_size
        shard = shard[:indices.size(0) - filled]  # clip the final partial epoch
        indices[filled: filled + shard.size(0)] = shard
        filled += shard.size(0)
    return indices[resume_step * global_batch_size // world_size:].tolist()
def get_packed_batch_sampler(
    dataset, rank, world_size, max_steps, resume_step, seed
):
    """
    Build a deterministic, resumable per-rank list of pre-packed batches.

    Same sharding scheme as get_train_sampler, but `dataset` is a sequence of
    pre-packed batches (index lists), so the result holds the batch *contents*
    rather than raw indices — one entry per training step. The first
    `resume_step` entries are dropped for mid-run resumption.
    """
    batches = [None] * max_steps
    epoch, filled, offset = 0, 0, 0
    while filled < len(batches):
        gen = torch.Generator()
        gen.manual_seed(seed + epoch)  # identical permutation across ranks
        perm = torch.randperm(len(dataset), generator=gen)
        epoch += 1
        # This rank's strided shard, rotated so remainders move between ranks.
        shard = perm[(rank + offset) % world_size::world_size]
        offset = (offset + world_size - len(dataset) % world_size) % world_size
        shard = shard[:len(batches) - filled]  # clip the final partial epoch
        batches[filled: filled + shard.size(0)] = [
            dataset[i] for i in shard
        ]
        filled += shard.size(0)
    return batches[resume_step:]
================================================
FILE: nit/models/c2i/nit_model.py
================================================
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Mlp
from einops import rearrange, repeat
from flash_attn import flash_attn_varlen_func
from nit.models.utils.funcs import get_parameter_dtype
from nit.models.utils.pos_embeds.rope import VisionRotaryEmbedding, rotate_half
from typing import Optional
def modulate(x, shift, scale):
    """Affine (adaLN) modulation: x * (1 + scale) + shift, broadcasting over x."""
    scaled = x * (scale + 1)
    return scaled + shift
def build_mlp(hidden_size, projector_dim, z_dim):
    """Three-layer SiLU MLP projector: hidden_size -> projector_dim -> projector_dim -> z_dim."""
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal frequency embedding of the timestep is passed through a
    two-layer SiLU MLP to produce the conditioning vector.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def positional_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output dims get a zero column so the shape is exactly (N, dim).
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    # Backward-compatible alias: forward() previously installed
    # `self.timestep_embedding` as an instance attribute on every call.
    timestep_embedding = positional_embedding

    def forward(self, t):
        # Fix: the original reassigned `self.timestep_embedding` inside forward
        # on every call — a needless per-call module mutation; call the static
        # embedding function directly instead.
        t_freq = self.positional_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout
    for classifier-free guidance: when dropout_prob > 0, one extra embedding row
    is reserved for the null/CFG class (the drop itself happens outside this
    module — forward is a plain table lookup).
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # Reserve one extra row for the null class iff CFG dropout is enabled.
        extra_rows = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + extra_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels):
        return self.embedding_table(labels)
#################################################################################
# Attention Block #
#################################################################################
class Attention(nn.Module):
    """
    Multi-head self-attention over a *packed* variable-length token sequence.

    Inputs are flattened across the batch into a single (N, C) tensor; per-image
    boundaries are carried by `cu_seqlens` (cumulative sequence lengths) and
    attention is computed with flash-attn's varlen kernel, so tokens never
    attend across image boundaries. 2D rotary embeddings (freqs_cos/freqs_sin)
    are applied to q and k before attention.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): `scale` and `attn_drop` are not used in forward — the
        # flash-attn kernel applies its own default softmax scaling and no
        # attention-dropout is passed to it.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, cu_seqlens, freqs_cos, freqs_sin) -> torch.Tensor:
        """
        x:          (N, C) packed tokens (all images concatenated).
        cu_seqlens: (B+1,) int tensor of cumulative per-image token counts.
        freqs_cos/freqs_sin: RoPE tables broadcastable to (N, num_heads, head_dim).
        Returns (N, C).
        """
        N, C = x.shape
        # (N, C) -> (3, N, num_heads, head_dim)
        qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        # Remember the pre-norm dtype: q/k-norm and RoPE may upcast, and
        # flash-attn requires q/k/v to share a (half-precision) dtype.
        ori_dtype = qkv.dtype
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        # Apply rotary position embedding to queries and keys.
        q = q * freqs_cos + rotate_half(q) * freqs_sin
        k = k * freqs_cos + rotate_half(k) * freqs_sin
        q, k = q.to(ori_dtype), k.to(ori_dtype)
        # Longest image in the pack, needed by the varlen kernel.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
        ).reshape(N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
#################################################################################
# Core NiT Model #
#################################################################################
class NiTBlock(nn.Module):
    """
    A NiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    shift/scale/gate for both the attention and MLP branches are regressed from
    the conditioning vector `c`; gates start at zero (see initialize_weights in
    NiT), so each block is initially the identity.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        # Affine-free norms: scale/shift come from adaLN instead.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
        )
        use_adaln_lora = block_kwargs.get('use_adaln_lora', False)
        if use_adaln_lora:
            # Low-rank factorized adaLN head (saves parameters per block).
            adaln_lora_dim = block_kwargs['adaln_lora_dim']
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, adaln_lora_dim, bias=True),
                nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True)
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )

    def forward(self, x, c, cu_seqlens, freqs_cos, freqs_sin):
        """
        x: (N, D) packed tokens; c: (N, D) per-token conditioning;
        cu_seqlens/freqs_cos/freqs_sin are forwarded to varlen attention.
        """
        # Six modulation tensors: (shift, scale, gate) for attn and for MLP.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
class FinalLayer(nn.Module):
    """
    The final layer of NiT: adaLN-modulated affine-free LayerNorm followed by a
    linear projection to patch_size**2 * out_channels values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        # Conditioning head producing (shift, scale) for the final norm.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        """x: (N, hidden_size) tokens; c: conditioning, broadcastable to x."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class NiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Native-resolution Transformer: consumes a *packed* batch of patch tokens
    from images of varying resolutions. Per-image shapes travel in `hw_list`
    (token-grid heights/widths); attention uses flash-attn varlen with
    cumulative sequence lengths, and positions are encoded with 2D RoPE.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        encoder_depth=4,
        projector_dim=2048,
        z_dim=768,
        use_checkpoint: bool = False,
        custom_freqs: str = 'normal',
        theta: int = 10000,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
        **block_kwargs # fused_attn
    ):
        super().__init__()
        self.in_channels = in_channels
        # Diffusion output has the same channel count as the input latent.
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.num_classes = num_classes
        # Block index after which projector features (zs) are taken.
        self.encoder_depth = encoder_depth
        self.use_checkpoint = use_checkpoint
        # strict_img_size=False lets the patch embedder accept arbitrary sizes.
        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False
        )
        self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        # 2D rotary position embedding shared by all attention blocks.
        self.rope = VisionRotaryEmbedding(
            head_dim=hidden_size//num_heads, custom_freqs=custom_freqs, theta=theta,
            max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple,
            ori_max_pe_len=ori_max_pe_len
        )
        # MLP head for representation-alignment features (zs).
        self.projector = build_mlp(hidden_size, projector_dim, z_dim)
        self.blocks = nn.ModuleList([
            NiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """DiT-style init: xavier linears, zeroed adaLN heads and output layer."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in NiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, patch_size=None):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)

        NOTE(review): assumes a *square* token grid (h == w, enforced by the
        assert); the final reshape uses h*p for both spatial dims.
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def get_rope(self, hw_list):
        """
        Build packed RoPE tables for a batch of token grids.

        hw_list: iterable of (h, w) token-grid sizes. Returns (freqs_cos,
        freqs_sin) with a singleton head dim inserted for broadcasting.
        """
        grids = []
        for h, w in hw_list:
            grid_h = torch.arange(h)
            grid_w = torch.arange(w)
            grid = torch.meshgrid(grid_h, grid_w, indexing='xy')
            grid = torch.stack(grid, dim=0).reshape(2, -1)
            grids.append(grid)
        # Concatenate per-image coordinate grids along the token axis.
        grids = torch.cat(grids, dim=-1)
        freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grids)
        return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)

    def forward(self, x, t, y, hw_list, return_zs=False, return_logvar=False):
        """
        Forward pass of NiT.
        x: (N, C, p, p) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        hw_list: (B, 2) int tensor of per-image token-grid (h, w); B images pack N patches.
        return_zs: also return projector features taken after `encoder_depth` blocks.
        """
        x = self.x_embedder(x) # (N, C, p, p) -> (N, 1, D), where T = H * W / patch_size ** 2
        x = x.squeeze(1) # (N, D)
        B = hw_list.shape[0]
        freqs_cos, freqs_sin = self.get_rope(hw_list) # (N, D_h)
        # Cumulative sequence lengths delimit images inside the packed batch.
        seqlens = hw_list[:, 0] * hw_list[:, 1]
        cu_seqlens = torch.cat([
            torch.tensor([0], device=hw_list.device, dtype=torch.int),
            torch.cumsum(seqlens, dim=0, dtype=torch.int)
        ])
        # timestep and class embedding
        t_embed = self.t_embedder(t) # (B, D)
        y = self.y_embedder(y) # (B, D)
        c = t_embed + y # (B, D)
        # (B, D) -> (N, D)
        c = torch.cat([c[i].unsqueeze(0).repeat(seqlens[i], 1) for i in range(B)], dim=0)
        zs=[]
        for i, block in enumerate(self.blocks):
            if not self.use_checkpoint:
                x = block(x, c, cu_seqlens, freqs_cos, freqs_sin) # (N, D)
            else:
                # Gradient checkpointing to trade compute for activation memory.
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, c, cu_seqlens, freqs_cos, freqs_sin
                )
            if (i + 1) == self.encoder_depth and return_zs:
                zs = [self.projector(x)]
        x = self.final_layer(x, c) # (N, out_channels * patch_size ** 2)
        # (N, out_channels * patch_size ** 2) -> (N, out_channels, p, p)
        x = rearrange(x, 'n (c p1 p2) -> n c p1 p2', p1=self.patch_size, p2=self.patch_size)
        if return_zs:
            return x, zs
        else:
            return x

    def ckpt_wrapper(self, module):
        """Wrap a module call so torch.utils.checkpoint can re-run it."""
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs
        return ckpt_forward

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)
================================================
FILE: nit/models/nvidia_radio/hubconf.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
dependencies = ["torch", "timm", "einops"]
import os
from typing import Dict, Any, Optional, Union, List
import warnings
import torch
from torch.hub import load_state_dict_from_url
from timm.models import clean_state_dict
from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs
def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Union[str, List[str]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool=False,
    **kwargs,
) -> RADIOModel:
    """
    torch.hub entry point: build a RADIO model from a named release or a local
    checkpoint path.

    version: release key in RESOURCE_MAP, or a filesystem path to a checkpoint;
             empty selects DEFAULT_VERSION.
    progress: show the download progress bar for remote checkpoints.
    adaptor_names: teacher adaptor head(s) to attach (name or list of names).
    vitdet_window_size: if set, rewires the backbone into ViTDet windowed attention.
    return_checkpoint: also return the raw loaded checkpoint dict.
    support_packing: forwarded onto the checkpoint args before model creation.
    """
    if not version:
        version = DEFAULT_VERSION
    # A local file path takes precedence over the release registry.
    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )
    # Prefer EMA weights when present; EMA checkpoints already have the
    # spectral reparametrization baked in, so disable that flag.
    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]
    args = chk["args"]
    args.support_packing = support_packing
    mod = create_model_from_args(args)
    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")
    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)
    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)
    state_dict = clean_state_dict(state_dict)
    # strict=False: adaptor/normalizer keys live outside the backbone prefix.
    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')
    if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
        # the method is that instead of there being a `weight` tensor for certain Linear layers
        # in the model, we make it a dynamically computed function. During training, this
        # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
        # Disabling it in this context means that instead of having `w' = f(w)`, we just compute `w' = f(w)`
        # once, during this function call, and replace the parametrization with the realized weights.
        # This makes the model run faster, and also use less memory.
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False
    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))
    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype
    # With one CLS token per teacher, collect the summary-token index of the
    # first occurrence of each teacher that contributes a summary.
    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)
    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]
    # Build the requested adaptor heads, slicing their weights out of the
    # flat checkpoint by either teacher index or teacher name prefix.
    teachers = chk["args"].teachers
    adaptors = dict()
    for adaptor_name in adaptor_names:
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')
        ttype = tconf["type"]
        pf_idx_head = f'_heads.{tidx}'
        pf_name_head = f'_heads.{adaptor_name}'
        pf_idx_feat = f'_feature_projections.{tidx}'
        pf_name_feat = f'_feature_projections.{adaptor_name}'
        adaptor_state = dict()
        for k, v in state_dict.items():
            if k.startswith(pf_idx_head):
                adaptor_state['summary' + k[len(pf_idx_head):]] = v
            elif k.startswith(pf_name_head):
                adaptor_state['summary' + k[len(pf_name_head):]] = v
            elif k.startswith(pf_idx_feat):
                adaptor_state['feature' + k[len(pf_idx_feat):]] = v
            elif k.startswith(pf_name_feat):
                adaptor_state['feature' + k[len(pf_name_feat):]] = v
        adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
        adaptor.head_idx = tidx if cls_token_per_teacher else 0
        adaptors[adaptor_name] = adaptor
    # Optional output feature normalizers (present only in some checkpoints).
    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)
    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )
    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )
    if return_checkpoint:
        return radio, chk
    return radio
def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    """Return the sub-state-dict of entries under `prefix`, with the prefix stripped."""
    cut = len(prefix)
    return {
        key[cut:]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
================================================
FILE: nit/models/nvidia_radio/radio/__init__.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# Register the adaptors
from .adaptor_registry import adaptor_registry
from . import open_clip_adaptor
from .adaptor_base import AdaptorInput, RadioOutput, AdaptorBase
# Enable support for other model types via the timm register_model mechanism
from . import extra_timm_models
from . import extra_models
from . import vision_transformer_xpos
================================================
FILE: nit/models/nvidia_radio/radio/adaptor_base.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import NamedTuple, Optional
import torch
from torch import nn
import torch.nn.functional as F
class AdaptorInput(NamedTuple):
    """Inputs handed to an adaptor head for a single forward pass."""
    images: torch.Tensor    # the preprocessed input images
    summary: torch.Tensor   # backbone summary (CLS-style) token(s)
    features: torch.Tensor  # backbone patch features
    feature_fmt: str        # requested output layout for features (e.g. 'NCHW')
    patch_size: int         # backbone patch size, used to recover the spatial grid
class RadioOutput(NamedTuple):
    """A (summary, features) result pair produced by RADIO or an adaptor head."""
    summary: torch.Tensor
    features: torch.Tensor

    def to(self, *args, **kwargs):
        """Apply Tensor.to(...) to both members; None entries pass through unchanged."""
        moved_summary = self.summary.to(*args, **kwargs) if self.summary is not None else None
        moved_features = self.features.to(*args, **kwargs) if self.features is not None else None
        return RadioOutput(moved_summary, moved_features)
class AdaptorBase(nn.Module):
    """Abstract base for teacher adaptor heads; subclasses must override forward."""
    def forward(self, input: AdaptorInput) -> RadioOutput:
        # Deliberately abstract: each registered adaptor supplies its own head logic.
        raise NotImplementedError("Subclasses must implement this!")
================================================
FILE: nit/models/nvidia_radio/radio/adaptor_generic.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
import torch
from torch import nn
import torch.nn.functional as F
from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config
class GenericAdaptor(AdaptorBase):
    """
    Default adaptor: two MLP heads — one for the summary token, one for patch
    features — built either from a checkpoint state dict or from an explicit
    MLP config. Optionally learns a feature upsampler (fd_upsample_factor).
    """
    def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
        super().__init__()
        extra_args = dict()
        # Feature-upsampling options can come from either config source.
        ups = None
        ups_rank = None
        if adaptor_config is not None:
            ups = adaptor_config.get('fd_upsample_factor', None)
            ups_rank = adaptor_config.get('fd_upsample_rank', None)
        elif mlp_config is not None:
            ups = mlp_config["feature"].get('upsample_factor', None)
            ups_rank = mlp_config["feature"].get('upsample_rank', None)
        if ups is not None:
            extra_args['upsample_factor'] = ups
            extra_args['upsample_rank'] = ups_rank
        if state is not None:
            # Rebuild the heads with shapes inferred from the checkpoint.
            spectral_heads = getattr(main_config, 'spectral_heads', False)
            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
        else:
            assert mlp_config is not None, "Config must not be None if state is None"
            self.head_mlp = create_mlp_from_config(
                main_config.mlp_version,
                mlp_config["summary"]["input_dim"],
                mlp_config["summary"]["hidden_dim"],
                mlp_config["summary"]["output_dim"],
                mlp_config["summary"]["num_inner"],
            )
            self.feat_mlp = create_mlp_from_config(
                main_config.mlp_version,
                mlp_config["feature"]["input_dim"],
                mlp_config["feature"]["hidden_dim"],
                mlp_config["feature"]["output_dim"],
                mlp_config["feature"]["num_inner"],
                **extra_args
            )

    def forward(self, input: AdaptorInput) -> RadioOutput:
        # Convert input'd type to the type of the first parameter of the adaptor.
        first_param = next(self.parameters())
        summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
        # Optionally fold the token sequence back into an NCHW feature map,
        # accounting for the head's upsample factor.
        if input.feature_fmt == 'NCHW':
            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
                    .permute(0, 3, 1, 2)
            )
        return RadioOutput(summary, feat)
================================================
FILE: nit/models/nvidia_radio/radio/adaptor_mlp.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
from typing import Dict, Optional
import torch
from torch import nn
from einops import rearrange
from timm.models.vision_transformer import Block
from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
class MLP(nn.Module):
    """
    ReLU MLP (version 'v1'): fc1 -> LayerNorm -> ReLU, an optional stack of
    `num_inner` identical (Linear, LayerNorm, ReLU) blocks, then fc2.
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0, device: torch.device = None, **kwargs):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.relu = nn.ReLU()
        inner_layers = []
        for _ in range(num_inner):
            inner_layers += [
                nn.Linear(hidden_size, hidden_size, device=device),
                nn.LayerNorm(hidden_size, device=device),
                nn.ReLU(),
            ]
        # Identity when no inner blocks were requested.
        self.inner = nn.Sequential(*inner_layers) if inner_layers else nn.Identity()
        self.fc2 = nn.Linear(hidden_size, output_size, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.relu(self.norm(self.fc1(x)))
        return self.fc2(self.inner(hidden))
class MLP2(nn.Module):
    """
    GELU MLP (version 'v2'): optional pre-norm, fc1, `num_inner` residual
    (LayerNorm, GELU, Linear) blocks, and a final (LayerNorm, GELU, Linear)
    head. With `upsample_factor > 1` the output tokens are unfolded into an
    upsampled spatial grid (requires `images` and `patch_size` at call time).
    """
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0,
                 pre_norm: bool = False, device: torch.device = None,
                 upsample_factor: int = 1,
                 upsample_rank: int = None,
                 from_config: bool = False,
                 **kwargs):
        super().__init__()
        self.pre_norm = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.GELU(),
        ) if pre_norm else nn.Identity()
        self.upsample_factor = upsample_factor
        # Each output token carries upsample_factor**2 sub-tokens of this width.
        self._real_output_dim = output_size // (upsample_factor ** 2)
        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.LayerNorm(hidden_size, device=device),
                nn.GELU(),
                nn.Linear(hidden_size, hidden_size, device=device),
            )
            for _ in range(num_inner)
        )
        self.final = nn.Sequential(
            nn.LayerNorm(hidden_size, device=device),
            nn.GELU(),
            nn.Linear(hidden_size, output_size, device=device),
        )

    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
        x = self.fc1(self.pre_norm(x))
        for residual_block in self.blocks:
            x = x + residual_block(x)
        x = self.final(x)
        if self.upsample_factor > 1:
            if images is None:
                raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
            if patch_size is None:
                raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
            # Token grid of the backbone output, before upsampling.
            h, w = tuple(d // patch_size for d in images.shape[-2:])
            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                          c=self._real_output_dim)
        return x
# Maps the checkpoint's `mlp_version` string to the implementing class.
MLP_FACTORY = {
    'v1': MLP,
    'v2': MLP2,
}
def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
    """Keep only entries whose key starts with `prefix`, with the prefix removed."""
    cut = len(prefix)
    return {
        key[cut:]: tensor
        for key, tensor in state.items()
        if key.startswith(prefix)
    }
def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
    """
    Infer (input_dim, hidden_dim, output_dim, num_inner) from an MLP state dict.

    When `spectral_weights` is set, weight tensors live under the
    parametrizations key instead of plain `weight`. `num_inner` is counted by
    probing consecutive inner-block keys until one is missing.
    """
    state = strip_prefix(state, prefix)
    weight_suffix = 'parametrizations.weight.original' if spectral_weights else 'weight'
    if version == 'v1':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'fc2.{weight_suffix}'].shape[0]
        num_inner = 0
        while f'inner.{num_inner}.0.weight' in state:
            num_inner += 1
    elif version == 'v2':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'final.2.{weight_suffix}'].shape[0]
        num_inner = 0
        while f'blocks.{num_inner}.0.weight' in state:
            num_inner += 1
    else:
        raise ValueError(f'Unsupported MLP version: {version}')
    return input_dim, hidden_dim, output_dim, num_inner
def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
    """Instantiate the MLP implementation registered for `version` with the given dims."""
    factory = MLP_FACTORY[version]
    ret: nn.Module = factory(input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
    return ret
def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
    """Build an MLP whose shape matches the (prefixed) state dict, then load its weights."""
    state = strip_prefix(state, prefix)
    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(
        version, state, spectral_weights=spectral_weights
    )
    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
    if spectral_weights:
        # Temporarily add the spectral parametrization so checkpoint keys line up.
        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
    ret.load_state_dict(state)
    if spectral_weights:
        # Bake the realized weights back in for inference.
        disable_spectral_reparam(ret)
    return ret
================================================
FILE: nit/models/nvidia_radio/radio/adaptor_registry.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import Dict, Any
import torch
from .adaptor_generic import GenericAdaptor, AdaptorBase
dict_t = Dict[str, Any]
state_t = Dict[str, torch.Tensor]
class AdaptorRegistry:
    """Name -> factory registry for adaptor heads, with a generic fallback."""

    def __init__(self):
        self._registry = {}

    def register_adaptor(self, name):
        """Decorator that registers a factory under ``name``; duplicates raise."""
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Model '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Instantiate the adaptor registered under ``name``, or a GenericAdaptor if unknown."""
        factory = self._registry.get(name)
        if factory is None:
            return GenericAdaptor(main_config, adaptor_config, state)
        return factory(main_config, adaptor_config, state)
# Creating an instance of the registry
# Module-level singleton: adaptor implementations register themselves on it
# via @adaptor_registry.register_adaptor(...) at import time.
adaptor_registry = AdaptorRegistry()
================================================
FILE: nit/models/nvidia_radio/radio/block.py
================================================
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Block modules
"""
import torch
import torch.nn as nn
from timm.models.layers import DropPath
from .conv import Conv
# from .transformer import TransformerBlock
__all__ = ('C2f', 'Bottleneck',)
class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        drop_path = drop_path if drop_path is not None else [0.0] * n
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        bottlenecks = [
            Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i])
            for i in range(n)
        ]
        self.m = nn.ModuleList(bottlenecks)

    def forward(self, x):
        """Forward pass through C2f layer."""
        parts = list(self.cv1(x).chunk(2, 1))
        for blk in self.m:
            parts.append(blk(parts[-1]))
        return self.cv2(torch.cat(parts, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        parts = list(self.cv1(x).split((self.c, self.c), 1))
        for blk in self.m:
            parts.append(blk(parts[-1]))
        return self.cv2(torch.cat(parts, 1))
class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):  # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, k[0], 1)
        self.cv2 = Conv(hidden, c2, k[1], 1, g=g)
        # The residual connection is only valid when shapes match.
        self.add = shortcut and c1 == c2
        self.drop_path1 = nn.Identity() if drop_path <= 0. else DropPath(drop_path)

    def forward(self, x):
        """'forward()' applies the YOLOv5 FPN to input data."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + self.drop_path1(out)
        return out
================================================
FILE: nit/models/nvidia_radio/radio/cls_token.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Optional
import torch
from torch import nn
class ClsToken(nn.Module):
    """Learnable class token (plus optional 'register' tokens) prepended to a token sequence."""

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
                 ):
        super().__init__()
        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if not enabled:
            self.token = None
        else:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
                # Pad the token count up to the next multiple of register_multiple.
                self.num_registers = register_multiple - (num_tokens % register_multiple)
            init_scale = ndim ** -0.5
            n_total = num_tokens + self.num_registers
            self.token = nn.Parameter(torch.randn(n_total, ndim) * init_scale)
        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        """Remove the token parameter; forward() becomes a pass-through."""
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x
        lead = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        return torch.cat([lead, x], dim=1)

    def no_weight_decay(self):
        """Exclude the token parameter from weight decay."""
        return [
            'token',
        ]
================================================
FILE: nit/models/nvidia_radio/radio/common.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from dataclasses import dataclass
from typing import Optional
from .radio_model import Resolution
@dataclass
class RadioResource:
    """Metadata describing one downloadable RADIO checkpoint."""
    url: str  # checkpoint download URL (HuggingFace)
    patch_size: int  # backbone patch size
    max_resolution: int  # largest supported input resolution, in pixels
    preferred_resolution: Resolution  # resolution the model was tuned for
    # ViTDet windowing configuration — semantics inferred from the field names;
    # confirm against the vitdet module before relying on them.
    vitdet_num_windowed: Optional[int] = None
    vitdet_num_global: Optional[int] = None
# Known RADIO checkpoint releases, keyed by version string.
# Fix: `preferred_resolution` is always a `Resolution` instance. Several entries
# previously passed bare tuples, which breaks downstream attribute access
# (`.height` / `.width`) even though tuple indexing still worked.
RESOURCE_MAP = {
    # RADIOv2.5
    "radio_v2.5-b": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-l": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-h": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-h-norm": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-g": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true",
        patch_size=14,
        max_resolution=1792,
        preferred_resolution=Resolution(896, 896),
        vitdet_num_global=8,
    ),
    # RADIO
    "radio_v2.1": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(432, 432),
        vitdet_num_windowed=5,
    ),
    "radio_v2": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(432, 432),
        vitdet_num_windowed=5,
    ),
    "radio_v1": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true",
        patch_size=14,
        max_resolution=1050,
        preferred_resolution=Resolution(378, 378),
    ),
    # E-RADIO
    "e-radio_v2": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(512, 512),
    ),
    # C-RADIO
    "c-radio_v2.5-g": RadioResource(
        "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(768, 768),
        vitdet_num_global=8,
    ),
    "c-radio_v3-l": RadioResource(
        # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
        # and accept the license terms.
        "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(512, 512),
    ),
}
# Version loaded when the caller does not specify one.
DEFAULT_VERSION = "radio_v2.5-h"
================================================
FILE: nit/models/nvidia_radio/radio/conv.py
================================================
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Convolution modules
"""
import math
import numpy as np
import torch
import torch.nn as nn
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        # Dilation inflates the effective kernel size.
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (x - 1) + 1 for x in k]
    if p is None:
        # Half the (effective) kernel size gives 'same' padding for odd kernels.
        if isinstance(k, int):
            p = k // 2
        else:
            p = [x // 2 for x in k]
    return p
# Pavlo's implementation with switch to deploy
class Conv(nn.Module):
    """Conv2d + BatchNorm + activation, with an optional BN-folding deploy mode."""
    default_act = nn.SiLU()  # default activation

    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
        super().__init__()
        # 'same'-style auto padding (inlined autopad logic).
        if padding is None:
            k_eff = kernel_size
            if dilation > 1:
                if isinstance(kernel_size, int):
                    k_eff = dilation * (kernel_size - 1) + 1
                else:
                    k_eff = [dilation * (v - 1) + 1 for v in kernel_size]
            padding = k_eff // 2 if isinstance(k_eff, int) else [v // 2 for v in k_eff]
        # NOTE: the conv is always bias-free here; BatchNorm provides the shift.
        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, g, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    @torch.no_grad()
    def switch_to_deploy(self):
        """Fold the BatchNorm statistics into the conv weights for inference."""
        if isinstance(self.bn, nn.Identity):
            return
        conv, bn = self.conv, self.bn
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        fused_w = conv.weight * scale[:, None, None, None]
        fused_b = bn.bias - bn.running_mean * scale
        self.conv.weight.data.copy_(fused_w)
        self.conv.bias = nn.Parameter(fused_b)
        self.bn = nn.Identity()
================================================
FILE: nit/models/nvidia_radio/radio/dinov2_arch.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
# Nvidia
# NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking,
# but also because Huggingface does a string replace of `gamma` to something else when loading the model state,
# and this breaks loading of this model.
from enum import Enum
from functools import partial
import logging
import math
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
# True when this torch build provides F.scaled_dot_product_attention (torch >= 2.0).
_torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention')
# xFormers can be force-disabled by setting the XFORMERS_DISABLED env var.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind
        XFORMERS_AVAILABLE = True
    else:
        # Treat "disabled" the same as "not installed".
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
def make_2tuple(x):
    """Return ``x`` as an (h, w) pair: ints are duplicated, 2-tuples pass through."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        def _pair(v):
            # Local equivalent of make_2tuple, kept inline so the class is self-contained.
            if isinstance(v, tuple):
                assert len(v) == 2
                return v
            assert isinstance(v, int)
            return (v, v)

        img_hw = _pair(img_size)
        patch_hw = _pair(patch_size)
        grid_hw = (img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1])

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        # Non-overlapping patchify: conv with kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
        emb = self.proj(x)  # B C H' W'
        grid_h, grid_w = emb.size(2), emb.size(3)
        emb = self.norm(emb.flatten(2).transpose(1, 2))  # B (H'W') C
        if not self.flatten_embedding:
            emb = emb.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H' W' C
        return emb

    def flops(self) -> float:
        """Approximate FLOP count for one forward pass over the configured grid."""
        Ho, Wo = self.patches_resolution
        total = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += Ho * Wo * self.embed_dim
        return total
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection; prefers torch SDPA when present."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        heads = self.num_heads
        # (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if _torch_has_sdpa:
            out = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=False,
                dropout_p=self.attn_drop.p if self.training else 0.,
                scale=self.scale,
            )
        else:
            # Manual fallback: scale queries, softmax over keys, weight values.
            attn = (q * self.scale) @ k.transpose(-2, -1)
            attn = self.attn_drop(attn.softmax(dim=-1))
            out = attn @ v

        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class MemEffAttention(Attention):
    """Attention variant that routes through xFormers' memory-efficient kernel when available."""

    def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        if not XFORMERS_AVAILABLE:
            # attn_bias only makes sense with the xFormers kernel (nested tensors).
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = unbind(qkv, 2)
        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([B, N, C])
        return self.proj_drop(self.proj(out))
class Mlp(nn.Module):
    """Two-layer ViT feed-forward block: fc1 -> act -> drop -> fc2 -> drop."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Unspecified widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: one fused projection to 2*hidden, SiLU-gated, then projected out."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # w12 produces both the gate and the value halves in one matmul.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
# When xFormers is unavailable, alias its fused SwiGLU to the local
# implementation so SwiGLUFFNFused below always has a valid base class.
if not XFORMERS_AVAILABLE:
    SwiGLU = SwiGLUFFN
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is rescaled by 2/3 and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # 2/3 rescale keeps the parameter count comparable to a vanilla MLP;
        # rounding up to a multiple of 8 keeps GPU kernels happy.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Stochastic depth: zero whole samples with probability ``drop_prob`` and rescale the rest."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all non-batch dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expectation matches the undropped activation.
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the functional form; no-op in eval mode.
        return drop_path(x, self.drop_prob, self.training)
class LayerScale(nn.Module):
    """Per-channel learnable scaling (DINOv2 LayerScale).

    The parameter is named ``grandma`` instead of the conventional ``gamma``
    because HuggingFace rewrites names containing 'gamma' when loading state.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, torch.Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.grandma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.grandma)
        return x * self.grandma

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Accept checkpoints saved under either the historical 'gamma' name or 'grandma'.
        key_a = f'{prefix}gamma'
        key_b = f'{prefix}grandma'

        if key_a in state_dict:
            value = state_dict[key_a]
        elif key_b in state_dict:
            value = state_dict[key_b]
        else:
            if strict:
                raise KeyError(f"Couldn't find the key {key_a} nor {key_b} in the state dict!")
            # Non-strict: record both aliases as missing and everything present as unexpected.
            missing_keys.append(key_a)
            missing_keys.append(key_b)
            unexpected_keys.extend(state_dict.keys())
            value = None

        if value is not None:
            self.grandma.data.copy_(value)
class Block(nn.Module):
    """Pre-norm transformer block: attention + MLP residual branches with optional
    LayerScale and stochastic depth (batched-subset variant above drop rate 0.1)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy (None or 0 disables it).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def run_attn(t: torch.Tensor) -> torch.Tensor:
            return self.ls1(self.attn(self.norm1(t)))

        def run_ffn(t: torch.Tensor) -> torch.Tensor:
            return self.ls2(self.mlp(self.norm2(t)))

        if self.training and self.sample_drop_ratio > 0.1:
            # Batch-subset stochastic depth: only worth the bookkeeping overhead
            # when the drop rate is high (> 0.1).
            x = drop_add_residual_stochastic_depth(
                x, residual_func=run_attn, sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=run_ffn, sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(run_attn(x))
            x = x + self.drop_path1(run_ffn(x))  # FIXME: drop_path2
        else:
            x = x + run_attn(x)
            x = x + run_ffn(x)
        return x
class NestedTensorBlock(Block):
    # Block variant that can also consume a *list* of tensors with differing
    # sequence lengths, packing them into one batch-1 sequence guarded by an
    # xFormers block-diagonal attention bias.
    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic-depth path. The residual funcs deliberately skip
            # ls1/ls2 here: the LayerScale factor is instead folded into the
            # fused scaled_index_add via `scaling_vector` below.
            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # NOTE(review): tests `self.ls1` but scales by `self.ls2.grandma`.
                # Harmless in practice since ls1/ls2 are created together, but it
                # reads like it was meant to check `self.ls2` (matches upstream DINOv2).
                scaling_vector=self.ls2.grandma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:
            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            # Pack all inputs into one sequence with a block-diagonal mask, run
            # the block once, then split the result back per input tensor.
            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Plain tensors use the ordinary Block path; lists require xFormers.
        if isinstance(x_or_x_list, torch.Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
def drop_add_residual_stochastic_depth(
    x: torch.Tensor,
    residual_func: Callable[[torch.Tensor], torch.Tensor],
    sample_drop_ratio: float = 0.0,
) -> torch.Tensor:
    """Apply ``residual_func`` to a random batch subset and scatter-add the rescaled result back."""
    # 1) pick a random subset of the batch (always at least one sample)
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:keep]
    subset = x[brange]

    # 2) compute the residual on the subset only
    residual = residual_func(subset).flatten(1)

    # 3) add it back, rescaled so the expectation matches a full-batch residual
    scale = batch / keep
    out = torch.index_add(x.flatten(1), 0, brange, residual.to(dtype=x.dtype), alpha=scale)
    return out.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random batch subset for stochastic depth; return (indices, rescale factor)."""
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)  # never drop everything
    brange = torch.randperm(batch, device=x.device)[:keep]
    scale = batch / keep
    return brange, scale
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add ``residual`` into rows ``brange`` of ``x``; optional per-channel scaling."""
    if scaling_vector is None:
        flat = x.flatten(1)
        out = torch.index_add(
            flat, 0, brange, residual.flatten(1).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    else:
        # xFormers fused path: applies the channel scaling during the scatter-add.
        out = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return out
# Module-level cache of block-diagonal attention masks, keyed by the tuple of
# per-input (batch_size, seq_len) shapes so identical layouts are reused.
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch size per input: the stochastic-depth subset sizes when
    # `branges` is given, otherwise the full batch of each tensor.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # Build (and memoize) a block-diagonal mask with one block per sample,
        # so the packed batch-1 sequence cannot attend across sample boundaries.
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes  # stashed for attn_bias.split() later
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Fused gather+concat of the selected rows, reshaped to one long sequence.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # No subsetting: just flatten each tensor to batch 1 and concatenate.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
    x_list: List[torch.Tensor],
    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> torch.Tensor:
    """Stochastic-depth residual over a list of inputs, fused into one packed attention pass."""
    # 1) pick a random keep-subset (and its rescale factor) for every input
    picks = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [p[0] for p in picks]
    scales = [p[1] for p in picks]
    # 2) gather the subsets into one packed sequence with a matching block-diagonal bias
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
    # 3) run the residual branch once over the packed sequence, then split per input
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, scales)
    ]
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    ``depth_first`` controls whether a node is visited after (True) or before
    (False) its children; ``include_root`` controls visiting the node itself.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        # Children are always visited, hence include_root=True below.
        named_apply(fn=fn, module=child, name=full_name, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
class BlockChunk(nn.ModuleList):
    """A ModuleList that is itself callable: runs its members sequentially."""

    def forward(self, x):
        for layer in self:
            x = layer(x)
        return x
class DinoVisionTransformer(nn.Module):
def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    ffn_bias=True,
    proj_bias=True,
    drop_path_rate=0.0,
    drop_path_uniform=False,
    init_values=None,  # for layerscale: None or 0 => no layerscale
    embed_layer=PatchEmbed,
    act_layer=nn.GELU,
    block_fn=Block,
    ffn_layer="mlp",
    block_chunks=1,
    num_register_tokens=0,
    interpolate_antialias=False,
    interpolate_offset=0.1,
):
    """
    Args:
        img_size (int, tuple): input image size
        patch_size (int, tuple): patch size
        in_chans (int): number of input channels
        embed_dim (int): embedding dimension
        depth (int): depth of transformer
        num_heads (int): number of attention heads
        mlp_ratio (int): ratio of mlp hidden dim to embedding dim
        qkv_bias (bool): enable bias for qkv if True
        proj_bias (bool): enable bias for proj in attn if True
        ffn_bias (bool): enable bias for ffn if True
        drop_path_rate (float): stochastic depth rate
        drop_path_uniform (bool): apply uniform drop rate across blocks
        weight_init (str): weight init scheme
        init_values (float): layer-scale init values
        embed_layer (nn.Module): patch embedding layer
        act_layer (nn.Module): MLP activation layer
        block_fn (nn.Module): transformer block class
        ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
        block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
        num_register_tokens: (int) number of extra cls tokens (so-called "registers")
        interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
        interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
    """
    super().__init__()
    norm_layer = partial(nn.LayerNorm, eps=1e-6)

    self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
    self.num_tokens = 1  # the single cls token
    self.n_blocks = depth
    self.num_heads = num_heads
    self.patch_size = patch_size
    self.num_register_tokens = num_register_tokens
    self.interpolate_antialias = interpolate_antialias
    self.interpolate_offset = interpolate_offset

    self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    # Positional embedding covers the patch grid plus the cls token.
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
    assert num_register_tokens >= 0
    self.register_tokens = (
        nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
    )

    if drop_path_uniform is True:
        dpr = [drop_path_rate] * depth
    else:
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

    if ffn_layer == "mlp":
        ffn_layer = Mlp
    elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
        ffn_layer = SwiGLUFFNFused
    elif ffn_layer == "identity":
        # "identity" disables the FFN entirely (ablation setting).
        def f(*args, **kwargs):
            return nn.Identity()

        ffn_layer = f
    else:
        raise NotImplementedError

    blocks_list = [
        block_fn(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            drop_path=dpr[i],
            norm_layer=norm_layer,
            act_layer=act_layer,
            ffn_layer=ffn_layer,
            init_values=init_values,
        )
        for i in range(depth)
    ]
    if block_chunks > 0:
        self.chunked_blocks = True
        chunked_blocks = []
        chunksize = depth // block_chunks
        for i in range(0, depth, chunksize):
            # this is to keep the block index consistent if we chunk the block list
            chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
        self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
    else:
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

    self.norm = norm_layer(embed_dim)
    self.head = nn.Identity()

    self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
def interpolate_pos_encoding(self, x, w, h):
    # Bicubically resize the stored (square) patch pos-embed grid to match the
    # current input's patch grid; the cls-token embedding is passed through.
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        # Input matches the training grid exactly; reuse the stored embedding.
        return self.pos_embed
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size
    M = int(math.sqrt(N))  # Recover the number of patches in each dimension
    assert N == M * M
    kwargs = {}
    if self.interpolate_offset:
        # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
        # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
        sx = float(w0 + self.interpolate_offset) / M
        sy = float(h0 + self.interpolate_offset) / M
        kwargs["scale_factor"] = (sx, sy)
    else:
        # Simply specify an output size instead of a scale factor
        kwargs["size"] = (w0, h0)
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
        mode="bicubic",
        antialias=self.interpolate_antialias,
        **kwargs,
    )
    assert (w0, h0) == patch_pos_embed.shape[-2:]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    # Re-attach the untouched cls embedding and restore the caller's dtype.
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
    """Patch-embed ``x``, apply mask tokens, prepend cls (and register) tokens, add pos-embed."""
    B, nc, w, h = x.shape
    x = self.patch_embed(x)
    if masks is not None:
        # Replace masked patch embeddings with the learned mask token.
        x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

    cls = self.cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls, x), dim=1)
    x = x + self.interpolate_pos_encoding(x, w, h)

    if self.register_tokens is not None:
        # Register tokens sit between the cls token and the patch tokens.
        regs = self.register_tokens.expand(x.shape[0], -1, -1)
        x = torch.cat((x[:, :1], regs, x[:, 1:]), dim=1)

    return x
def forward_features_list(self, x_list, masks_list):
    """Run the backbone over several (input, mask) pairs via nested-tensor blocks."""
    tokens = [self.prepare_tokens_with_masks(t, m) for t, m in zip(x_list, masks_list)]
    for blk in self.blocks:
        tokens = blk(tokens)

    outputs = []
    for t, masks in zip(tokens, masks_list):
        normed = self.norm(t)
        outputs.append({
            "x_norm_clstoken": normed[:, 0],
            "x_norm_regtokens": normed[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": normed[:, self.num_register_tokens + 1 :],
            "x_prenorm": t,
            "masks": masks,
        })
    return outputs
def forward_features(self, x, masks=None):
    """Backbone forward pass; dispatches to the list variant for list input."""
    if isinstance(x, list):
        return self.forward_features_list(x, masks)
    tokens = self.prepare_tokens_with_masks(x, masks)
    for blk in self.blocks:
        tokens = blk(tokens)
    normed = self.norm(tokens)
    n_reg = self.num_register_tokens
    return {
        "x_norm_clstoken": normed[:, 0],
        "x_norm_regtokens": normed[:, 1 : n_reg + 1],
        "x_norm_patchtokens": normed[:, n_reg + 1 :],
        "x_prenorm": tokens,
        "masks": masks,
    }
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
    """Chunked variant of intermediate-layer collection.

    Here `self.blocks` is a sequence of block *chunks*; `i` counts blocks
    globally across chunks, and each chunk is sliced with `[i:]` to skip
    its leading identity padding.
    """
    x = self.prepare_tokens_with_masks(x)
    # NOTE(review): total depth is read from the *last* chunk's length —
    # this assumes chunks are identity-padded so that
    # len(self.blocks[-1]) equals the full block count; confirm against
    # the block-chunking constructor.
    output, i, total_block_len = [], 0, len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    for block_chunk in self.blocks:
        for blk in block_chunk[i:]:  # Passing the nn.Identity()
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
            i += 1
    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch features of selected blocks, optionally normalized,
    reshaped into 2D maps, and/or paired with their class tokens."""
    getter = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    outputs = getter(x, n)
    if norm:
        outputs = [self.norm(o) for o in outputs]
    class_tokens = [o[:, 0] for o in outputs]
    # Strip cls + register tokens, keeping only the patch tokens.
    outputs = [o[:, 1 + self.num_register_tokens :] for o in outputs]
    if reshape:
        B, _, w, h = x.shape
        grid_h, grid_w = w // self.patch_size, h // self.patch_size
        outputs = [
            o.reshape(B, grid_h, grid_w, -1).permute(0, 3, 1, 2).contiguous()
            for o in outputs
        ]
    if return_class_token:
        return tuple(zip(outputs, class_tokens))
    return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
    """Return the raw feature dict while training; otherwise apply the
    classification head to the normalized cls token."""
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-S backbone (embed_dim 384, depth 12, 6 heads)."""
    block_ctor = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=block_ctor,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-B backbone (embed_dim 768, depth 12, 12 heads)."""
    block_ctor = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=block_ctor,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-L backbone (embed_dim 1024, depth 24, 16 heads)."""
    block_ctor = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=block_ctor,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a DINOv2 ViT-g backbone.

    Close to ViT-giant, with embed_dim 1536 and 24 heads, i.e. 64 channels
    per attention head.
    """
    block_ctor = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=block_ctor,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
class Weights(Enum):
    """Named pretrained-weight releases for DINOv2 backbones."""

    LVD142M = "LVD142M"
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 backbone by resolving `arch_name` in this module.

    `weights` may be a `Weights` member or its string name; an unknown
    name raises AssertionError. Extra kwargs override the defaults
    assembled here.
    """
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)

    # Look up the builder function (vit_small/base/large/giant2) by name.
    builder = sys.modules[__name__].__dict__[arch_name]
    return builder(**vit_kwargs)
def dinov2_vits14(**kwargs):
    """Construct a DINOv2 ViT-S/14 backbone (optionally LVD-142M pretrained)."""
    model = _make_dinov2_model(arch_name="vit_small", **kwargs)
    return model
def dinov2_vitb14(**kwargs):
    """Construct a DINOv2 ViT-B/14 backbone (optionally LVD-142M pretrained)."""
    model = _make_dinov2_model(arch_name="vit_base", **kwargs)
    return model
def dinov2_vitl14(**kwargs):
    """Construct a DINOv2 ViT-L/14 backbone (optionally LVD-142M pretrained)."""
    model = _make_dinov2_model(arch_name="vit_large", **kwargs)
    return model
def dinov2_vitg14(**kwargs):
    """Construct a DINOv2 ViT-g/14 backbone (optionally LVD-142M pretrained)."""
    model = _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        **kwargs,
    )
    return model
def dinov2_vits14_reg(**kwargs):
    """Construct a DINOv2 ViT-S/14 backbone with 4 register tokens."""
    model = _make_dinov2_model(
        arch_name="vit_small",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
    return model
def dinov2_vitb14_reg(**kwargs):
    """Construct a DINOv2 ViT-B/14 backbone with 4 register tokens."""
    model = _make_dinov2_model(
        arch_name="vit_base",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
    return model
def dinov2_vitl14_reg(**kwargs):
    """Construct a DINOv2 ViT-L/14 backbone with 4 register tokens."""
    model = _make_dinov2_model(
        arch_name="vit_large",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
    return model
def dinov2_vitg14_reg(**kwargs):
    """Construct a DINOv2 ViT-g/14 backbone with 4 register tokens."""
    model = _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
    return model
================================================
FILE: nit/models/nvidia_radio/radio/dual_hybrid_vit.py
================================================
from logging import getLogger
from typing import Tuple
import torch
from torch import nn
from torch.nn import functional as F
from timm.models import register_model
from timm.models import vision_transformer as tvit
from timm.models import convnext as tconv
from einops import rearrange
from . import extra_timm_models as et
class Fuser(nn.Module):
    """Fuses a source feature map into a target via a small conv network,
    optionally gated by a sigmoid branch.

    Either input may arrive as a token sequence (B, N, C); it is reshaped
    to (B, C, H, W) using the spatial shape of the other input, and any
    prefix tokens of the target are carried through untouched.
    """

    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
        super().__init__()
        self.gated = gated
        mid_dim = max(src_dim, tgt_dim) * 2
        out_dim = tgt_dim * (2 if gated else 1)
        self.fwd = nn.Sequential(
            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(mid_dim, out_dim, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        # Spatial shape comes from whichever input is already (B, C, H, W).
        shape = tgt.shape[-2:] if src.ndim == 3 else src.shape[-2:]
        nd = shape[0] * shape[1]
        if src.ndim == 3:
            # Token sequence -> spatial map, dropping any prefix tokens.
            src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)
        tgt_pre = None
        if tgt.ndim == 3:
            tgt_pre = tgt[:, :-nd]
            tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
        pred = self.fwd(src)
        if self.gated:
            gate, pred = torch.chunk(pred, 2, dim=1)
            pred = F.sigmoid(gate) * pred
        fused = tgt + pred
        if tgt_pre is not None:
            # Back to token form, prefix tokens restored up front.
            fused = rearrange(fused, 'b c h w -> b (h w) c')
            fused = torch.cat([tgt_pre, fused], dim=1)
        return fused
class AttnDownsample(nn.Module):
    """Downsample patch tokens by `window_size` using attention: one learned
    query attends over each non-overlapping window, producing one output
    token per window; leading (cls/register) tokens pass through unchanged."""

    def __init__(self, dim: int, window_size: int, num_heads: int = 16):
        super().__init__()
        # Single learned query (per head), shared by every window.
        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
        """x: [B, n_prefix + H*W, C]; twod_shape: (H, W) of the patch grid."""
        ntok = twod_shape[0] * twod_shape[1]
        # Prefix (cls/register) tokens are excluded from downsampling.
        x_pre = x[:, :-ntok]
        B = x.shape[0]
        ds_hw = tuple(s // self.window_size for s in twod_shape)
        # Regroup patch tokens into (batch*windows, window_size^2, C).
        x_spat = rearrange(
            x[:, -ntok:],
            'b (h d1 w d2) c -> (b h w) (d1 d2) c',
            h=ds_hw[0], w=ds_hw[1],
            d1=self.window_size, d2=self.window_size,
        )
        # NOTE: B is rebound here to batch * num_windows.
        B, N, C = x_spat.shape
        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q = (self.q * self.scale).expand(B, -1, -1, -1)
        attn = q @ k.transpose(-2, -1)
        attn = F.softmax(attn, dim=-1)
        x = attn @ v
        # One query per window -> sequence length 1, so this flattens to (B, C).
        x = x.transpose(1, 2).reshape(B, C)
        x = self.proj(x)
        # Restore (batch, windows, C) layout and re-attach prefix tokens.
        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])
        x = torch.cat([x_pre, x], dim=1)
        return x
class HybridModel(nn.Module):
    """Dual-branch backbone pairing a timm VisionTransformer with a ConvNeXt.

    Both branches run their first half independently, exchange information
    through gated `Fuser` modules, finish their second half, and are merged
    either by channel concatenation (`concatenate=True`) or by a final
    ungated fusion, followed by one transformer block.
    """

    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
                 concatenate: bool = False, **kwargs):
        super().__init__()
        self.conv = conv
        self.vit = vit
        self.concatenate = concatenate

        # Re-wrap so stages/blocks can be indexed and split at the midpoint.
        conv.stages = nn.ModuleList(conv.stages)
        vit.blocks = nn.ModuleList(vit.blocks)

        self._half_vit_idx = len(vit.blocks) // 2 + 1
        self._half_conv_idx = None

        # Dry-run the conv branch on a dummy 256x256 input to locate the
        # stage where the feature map reaches 16x16 (the fusion point) and
        # to read the channel widths used to size the fusers below.
        x = torch.empty(1, 3, 256, 256)
        x = self.conv.stem(x)
        for i in range(len(conv.stages)):
            x = conv.stages[i](x)
            if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
                self._half_conv_idx = i + 1
                half_conv_dim = x.shape[1]
        final_conv_dim = x.shape[1]

        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
        # Attention downsample aligns ViT token resolution to the conv grid.
        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)

        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
        if not concatenate:
            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
        self.final_block = tvit.Block(embed_dim, num_heads=16)
        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        # Effective output stride of the merged feature map.
        return 32

    @property
    def no_fsdp_wrap_types(self):
        # Branch roots that FSDP should not wrap individually.
        return {tvit.VisionTransformer, tconv.ConvNeXt}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_features(x)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches with mid-network cross-fusion.

        Returns:
            (summary, features): cls-token summary and patch features.
        """
        # First half of the ViT branch (CPE patch generator tokenizes x).
        y_vit = self.vit.patch_generator(x)
        for i in range(self._half_vit_idx):
            y_vit = self.vit.blocks[i](y_vit)
        # First half of the conv branch.
        y_conv = self.conv.stem(x)
        for i in range(self._half_conv_idx):
            y_conv = self.conv.stages[i](y_conv)
        # Symmetric cross-fusion (both use pre-fusion values), then 2x
        # downsample of the ViT tokens to track the conv resolution.
        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)
        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])
        # Second halves.
        for i in range(self._half_vit_idx, len(self.vit.blocks)):
            y_vit = self.vit.blocks[i](y_vit)
        for i in range(self._half_conv_idx, len(self.conv.stages)):
            y_conv = self.conv.stages[i](y_conv)
        if self.concatenate:
            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
            # Average pool across the board, and replicate for each cls/register token
            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
            y_conv = torch.cat([conv_summary, y_conv], dim=1)
            y = torch.cat([y_vit, y_conv], dim=2)
        else:
            y = self.final_fuse(y_conv, y_vit)
        y = self.final_block(y)
        # NOTE(review): summary keeps num_cls_tokens but features skip
        # num_cls_patches — presumably they differ by the register tokens,
        # which end up in neither output; confirm against ViTPatchGenerator.
        summary = y[:, :self.vit.patch_generator.num_cls_tokens]
        features = y[:, self.vit.patch_generator.num_cls_patches:]
        return summary, features
@register_model
def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    """ConvNeXt-V2-Base + ViT-B/16 hybrid backbone."""
    common = dict(num_classes=0, **kwargs)
    conv_net = tconv.convnextv2_base(pretrained=pretrained, **common)
    vit_net = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **common)
    return HybridModel(vit_net, conv_net, pretrained, concatenate=concatenate)
@register_model
def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    """ConvNeXt-V2-Large + ViT-L/16 hybrid backbone."""
    common = dict(num_classes=0, **kwargs)
    conv_net = tconv.convnextv2_large(pretrained=pretrained, **common)
    vit_net = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **common)
    return HybridModel(vit_net, conv_net, pretrained, concatenate=concatenate)
@register_model
def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    """ConvNeXt-V2-Huge + ViT-H/16 hybrid backbone."""
    common = dict(num_classes=0, **kwargs)
    conv_net = tconv.convnextv2_huge(pretrained=pretrained, **common)
    vit_net = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **common)
    return HybridModel(vit_net, conv_net, pretrained, concatenate=concatenate)
================================================
FILE: nit/models/nvidia_radio/radio/enable_cpe_support.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import List, Optional, Set, Tuple, Union
from types import MethodType
import torch
from torch import nn
from timm.models import VisionTransformer, checkpoint_seq
from timm.models.vision_transformer import Attention, Block
from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
from .extra_models import DinoWrapper
from .vit_patch_generator import ViTPatchGenerator
from .forward_intermediates import forward_intermediates
from .dual_hybrid_vit import HybridModel
from flash_attn import flash_attn_varlen_func
def _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Varlen (packed) attention over concatenated sequences.

    `x` is [total_tokens, C]; `cu_seqlens` holds int32 cumulative sequence
    boundaries as required by flash-attn's varlen kernel.
    """
    N, C = x.shape
    qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)
    # Longest individual sequence within the pack.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    out = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
    )
    out = out.reshape(N, -1)
    out = self.proj(out)
    return self.proj_drop(out)
def _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Transformer block forward for packed sequences; only the attention
    sub-layer needs the `cu_seqlens` boundaries."""
    attn_out = self.ls1(self.attn(self.norm1(x), cu_seqlens))
    x = x + self.drop_path1(attn_out)
    mlp_out = self.ls2(self.mlp(self.norm2(x)))
    return x + self.drop_path2(mlp_out)
def _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Packed-sequence forward: patchify each image, concatenate all tokens
    into one ragged batch, and run the blocks with varlen attention.

    Args:
        images: list of [1, c, H, W] tensors (possibly differing H/W).

    Returns:
        (tokens, cu_seqlens): normalized tokens of shape [sum_i n_i, D] and
        int32 cumulative sequence lengths of shape [len(images) + 1].
    """
    device = images[0].device
    x = []
    seqlens = []
    for image in images:
        # image: [1, c, H, W] -> tokens: [n_cls + h*w, D], h=H/p and w=W/p
        tokens = self.patch_generator(image).squeeze(0)
        x.append(tokens)
        seqlens.append(tokens.shape[0])
    x = torch.cat(x, dim=0)
    seqlens = torch.tensor(seqlens, device=device, dtype=torch.int)
    cu_seqlens = torch.cat([
        torch.tensor([0], device=device, dtype=torch.int32),
        torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    ])
    if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
        # Fix: checkpoint each block individually so `cu_seqlens` is actually
        # forwarded to the block. The previous `checkpoint_seq(block, x,
        # cu_seqlens)` call passed `cu_seqlens` as checkpoint_seq's `every`
        # argument and invoked the block without it.
        from torch.utils.checkpoint import checkpoint
        for block in self.blocks:
            x = checkpoint(block, x, cu_seqlens, use_reentrant=False)
    else:
        for block in self.blocks:
            x = block(x, cu_seqlens)
    x = self.norm(x)
    return x, cu_seqlens
def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
    """Forward features using the CPE patch generator in place of the
    standard patch embedding."""
    x = self.patch_generator(x)
    use_ckpt = getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting()
    if use_ckpt:
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    return self.norm(x)
def _take_indices(
num_blocks: int,
n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def _forward_intermediates_cpe(
    self,
    x: torch.Tensor,
    norm: bool = False,
    **kwargs,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Delegate to the shared `forward_intermediates` helper, wired up to
    the CPE patch generator's token layout."""
    final_norm = self.norm if norm else (lambda y: y)
    return forward_intermediates(
        self,
        patch_extractor=self.patch_generator,
        num_summary_tokens=self.patch_generator.num_skip,
        num_cls_tokens=self.patch_generator.num_cls_tokens,
        norm=final_norm,
        x=x,
        **kwargs,
    )
def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor:
    """CPE forward for a wrapped DINOv2 model: returns (cls summary, patch
    features with summary tokens stripped)."""
    feats = _forward_cpe(self.inner, x)
    return feats[:, 0], feats[:, self.num_summary_tokens:]
def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):
    """Run CPE intermediate-feature extraction on the wrapped DINOv2 model."""
    inner_model = self.inner
    return _forward_intermediates_cpe(inner_model, *args, **kwargs)
def _enable_cpe_for_timm_vit(model: VisionTransformer,
                             max_img_size: Union[int, Tuple[int, int]] = 1024,
                             num_cls_tokens: int = 1,
                             pos_dropout: float = 0.1,
                             register_multiple: Optional[int] = None,
                             num_registers: Optional[int] = None,
                             support_packing: bool = False,
                             ):
    """Replace a timm VisionTransformer's patch embedding with a CPE
    `ViTPatchGenerator`, optionally enabling packed (varlen) attention.

    Fix: the defaults for `register_multiple`/`num_registers` were the
    expression `Optional[None]`, which evaluates to the NoneType *class*
    (a truthy object) rather than None; they are now plain `None`.

    Raises:
        ValueError: if `model` is not a timm VisionTransformer.
    """
    if not isinstance(model, VisionTransformer):
        raise ValueError("CPE only support for VisionTransformer models!")

    patch_size = model.patch_embed.patch_size[0]
    embed_dim = model.embed_dim
    input_dims = model.patch_embed.img_size
    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
    cls_token = model.cls_token is not None

    # Snap the maximum supported image size to a multiple of the patch size.
    max_img_size = int(round(max_img_size / patch_size) * patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
        num_registers=num_registers,
    )

    # The generator subsumes patch embed, cls token, and position embedding.
    model.patch_generator = patch_generator
    model.patch_embed = None
    model.cls_token = None
    model.pos_embed = None
    model.pos_drop = None
    model.patch_size = patch_size
    model.num_cls_tokens = num_cls_tokens
    model.num_registers = patch_generator.num_registers

    model.forward_features = MethodType(_forward_cpe, model)
    model.forward_intermediates = MethodType(_forward_intermediates_cpe, model)

    if support_packing:
        # Packed mode overrides forward_features and rebinds every block's
        # attention to the varlen (flash-attn) implementation.
        model.forward_features = MethodType(_forward_cpe_pack, model)
        for block in model.blocks:
            block.forward = MethodType(_block_forward_pack, block)
            block.attn.forward = MethodType(_attn_forward_pack, block.attn)
def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,
                                max_img_size: Union[int, Tuple[int, int]] = 1024,
                                num_cls_tokens: int = 1,
                                pos_dropout: float = 0.1,
                                register_multiple: Optional[int] = None,
                                num_registers: Optional[int] = None,
                                ):
    """Replace a wrapped DINOv2 (register) ViT's patch embedding with a CPE
    `ViTPatchGenerator`.

    Fix: the defaults for `register_multiple`/`num_registers` were the
    expression `Optional[None]`, which evaluates to the NoneType *class*
    (a truthy object) rather than None; they are now plain `None`.
    """
    patch_size = model.patch_size
    embed_dim = model.embed_dim
    input_dims = model.inner.patch_embed.patches_resolution
    normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity)
    cls_token = True

    # Snap the maximum supported image size to a multiple of the patch size.
    max_img_size = int(round(max_img_size / patch_size) * patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
        num_registers=num_registers,
        patch_bias=True,
    )

    # The generator subsumes patch embed, cls/register tokens, and pos embed.
    inner = model.inner
    inner.patch_generator = patch_generator
    inner.patch_embed = None
    inner.cls_token = None
    inner.pos_embed = None
    inner.register_tokens = None
    inner.patch_size = patch_size

    model.forward_features = MethodType(_forward_cpe_dinov2, model)
    model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model)
def enable_cpe(model: nn.Module,
               *args,
               **kwargs,
               ):
    """Enable Cropped Position Embedding on a supported backbone, in place.

    Raises:
        ValueError: for unsupported model types.
    """
    if isinstance(model, HybridModel):
        # Hybrids only CPE-enable their ViT branch.
        target, enabler = model.vit, _enable_cpe_for_timm_vit
    elif isinstance(model, VisionTransformer):
        target, enabler = model, _enable_cpe_for_timm_vit
    elif isinstance(model, DinoWrapper):
        target, enabler = model, _enable_cpe_for_dv2_reg_vit
    else:
        raise ValueError(f'CPE not supported for this model type: {type(model)}')
    enabler(target, *args, **kwargs)
================================================
FILE: nit/models/nvidia_radio/radio/enable_damp.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from logging import getLogger
import math
import os
from typing import Dict, List, Optional, Union, Tuple
from types import MethodType
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize
# For now, don't do anything
class DAMP(nn.Identity):
    """Placeholder parametrization: currently a pass-through that only
    records the requested noise std."""

    def __init__(self, std: float):
        super().__init__()
        self.std = std
def enable_damp(model: nn.Module, std: float):
    """Attach a DAMP weight parametrization to every nn.Linear in `model`
    (recursing when given a list/tuple of models)."""
    if isinstance(model, (list, tuple)):
        for sub in model:
            enable_damp(sub, std)
        return
    for _, module in model.named_modules():
        if isinstance(module, nn.Linear):
            parametrize.register_parametrization(module, 'weight', DAMP(std))
def configure_damp_from_args(model: nn.Module, args):
    """Enable DAMP when `args.damp` is set to a truthy std value."""
    std = getattr(args, 'damp', None)
    if std:
        enable_damp(model, std)
================================================
FILE: nit/models/nvidia_radio/radio/enable_spectral_reparam.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from logging import getLogger
import math
import os
from typing import Dict, List, Optional, Union, Tuple
from types import MethodType
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import _SpectralNorm
from timm.models.vision_transformer import Attention, Mlp
_EPS = 1e-5
class _SNReweight(_SpectralNorm):
    """Spectral-norm reparametrization with a learnable scale.

    The effective weight is `weight * scale / sigma(weight)`, where sigma
    is the (power-iteration) spectral norm. Version 2 parameterizes the
    scale through softplus plus a floor `alpha` to keep it positive.
    """

    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
        super().__init__(weight, *args, **kwargs)

        self.alpha = alpha
        self.version = version
        # Persisted so old checkpoints can be detected in _load_from_state_dict.
        self.register_buffer('_sn_version', torch.tensor(version))

        if init_norm_to_current:
            # This will set the numerator to match the denominator, which should preserve the original values
            init_scale = self._get_sigma(weight, n_power_iterations=20).item()
        else:
            init_scale = 1.0

        if version == 1:
            init_value = init_scale
        elif version == 2:
            # Invert softplus so that softplus(init_value) + alpha == init_scale.
            t = init_scale - alpha
            if t < _EPS:
                getLogger("spectral_reparam").warn(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
                t = _EPS

            init_value = math.log(math.exp(t) - 1)
        else:
            raise ValueError(f'Unsupported version: {version}')

        # Make 2D so that weight decay gets applied
        self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))

    # Re-implementing this because we need to make division by sigma safe
    def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = None) -> torch.Tensor:
        """Return the spectral norm of `weight` plus `eps` (safe divisor)."""
        if not n_power_iterations:
            n_power_iterations = self.n_power_iterations
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            sigma = weight.norm()
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                # u/v power-iteration state is only updated while training.
                self._power_method(weight_mat, n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))

        return sigma + self.eps

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        """Return the reweighted weight, preserving half-precision dtypes."""
        dtype = weight.dtype
        sigma = self._get_sigma(weight, *args, **kwargs)

        if self.version == 1:
            scale = self.scale
        elif self.version == 2:
            scale = F.softplus(self.scale) + self.alpha
        else:
            raise ValueError(f'Unsupported version: {self.version}')

        # Scaling is done in float32 for numerical stability.
        scale = scale.float() / sigma.float()

        y = weight * scale

        if dtype in (torch.float16, torch.bfloat16):
            y = y.to(dtype)
        return y

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Checkpoints without a version buffer predate version 2 semantics.
        version_key = f'{prefix}_sn_version'
        if version_key not in state_dict:
            self.version = 1
            state_dict[version_key] = torch.tensor(1)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
class _ChunkedSNReweight(nn.Module):
    """Applies an independent `_SNReweight` to each of `num_chunks` equal
    row-chunks of a weight matrix (e.g. a fused qkv projection)."""

    def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):
        super().__init__()
        self.num_chunks = num_chunks
        chunk_rows = weight.shape[0] // num_chunks
        self.parts = nn.ModuleList(
            _SNReweight(chunk, *args, init_norm_to_current=init_norm_to_current, **kwargs)
            for chunk in weight.split(chunk_rows, dim=0)
        )

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        chunk_rows = weight.shape[0] // self.num_chunks
        reweighted = [
            fn(chunk)
            for fn, chunk in zip(self.parts, weight.split(chunk_rows, dim=0))
        ]
        return torch.cat(reweighted, dim=0)
class _AttnSNReweight(_ChunkedSNReweight):
    """Three-way chunked reweighting for fused qkv weights; the value chunk
    is left untouched unless `renorm_values` is set."""

    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
        super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)
        if not renorm_values:
            # Keep the v projection un-normalized by default.
            self.parts[2] = nn.Identity()
def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
                            n_power_iterations: int = 1,
                            eps: float = 1e-6,
                            init_norm_to_current: bool = False,
                            renorm_values: bool = True,
                            renorm_mlp: bool = True,
                            state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    """Register spectral-norm reweighting parametrizations on linear layers.

    Attention qkv projections get a 3-way chunked reweight, SwiGLU-style
    MLPs (modules exposing `w12`/`w3`) get chunked/plain reweights, and any
    remaining nn.Linear gets a plain reweight. When `state_dict_guidance`
    is given, only modules whose parametrization buffers (`_sn_version`)
    appear in that state dict are parametrized.
    """
    if isinstance(model, (list, tuple)):
        # Recurse per sub-model, pairing each with its guidance when the
        # guidance itself is a list.
        for i, sub in enumerate(model):
            sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance
            enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,
                                    init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,
                                    renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)
        return

    print('Enabling spectral reparametrization')
    args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)
    visited_prefixes = set()

    def is_guidance_parametrized(name: str):
        # Without guidance, every eligible module is parametrized.
        if state_dict_guidance is None:
            return True

        p_name = f'{name}.parametrizations'
        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
        return is_prm

    def parametrize_linear(linear: nn.Linear):
        parametrize.register_parametrization(
            linear,
            'weight',
            _SNReweight(linear.weight, **args)
        )

    for name, mod in model.named_modules():
        # Skip children of a module that was already handled as a unit
        # (attention / mlp blocks below record themselves in visited_prefixes).
        pref = '.'.join(name.split('.')[:-1])
        if pref in visited_prefixes:
            continue

        if isinstance(mod, Attention) or name.endswith('.attn'):
            if is_guidance_parametrized(f'{name}.qkv'):
                parametrize.register_parametrization(
                    mod.qkv,
                    'weight',
                    _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),
                )
            if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):
                parametrize_linear(mod.proj)
            visited_prefixes.add(name)
        elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):
            # SwiGLU MLP: w12 is a fused two-chunk input projection.
            if is_guidance_parametrized(f'{name}.w12'):
                parametrize.register_parametrization(
                    mod.w12,
                    'weight',
                    _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),
                )
            if is_guidance_parametrized(f'{name}.w3'):
                parametrize_linear(mod.w3)
            visited_prefixes.add(name)
        elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):
            parametrize_linear(mod)
def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    """Enable spectral reparametrization according to `args.spectral_reparam`,
    which may be a bool flag or a dict of overrides."""
    spectral_reparam = getattr(args, 'spectral_reparam', False)
    if isinstance(spectral_reparam, dict):
        enable_spectral_reparam(
            model,
            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
            eps=spectral_reparam.get('eps', 1e-12),
            init_norm_to_current=True,
            state_dict_guidance=state_dict_guidance,
        )
    elif isinstance(spectral_reparam, bool) and spectral_reparam:
        enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)
def disable_spectral_reparam(model: nn.Module):
    """Bake every module's 'weight' parametrization back into a plain tensor
    and remove the parametrization."""
    print('Disabling spectral reparametrization')
    for _, mod in model.named_modules():
        if parametrize.is_parametrized(mod):
            parametrize.remove_parametrizations(mod, 'weight')
if __name__ == '__main__':
    # CLI utility: load a checkpoint trained with spectral reparametrization,
    # bake the parametrized weights back into plain tensors, and save a
    # cleaned checkpoint next to the original.
    import argparse
    # NOTE(review): relative import — this only works when executed as a
    # module within the package (python -m ...), not as a standalone script.
    from . import radio_model as create_model

    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
    args = parser.parse_args()

    # Default output path: "clean_<name>" alongside the input checkpoint.
    if not args.output:
        chk_dir, chk_name = os.path.split(args.checkpoint)
        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
        print(f'Set output to "{args.output}"')

    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)

    model = create_model.create_model_from_args(chk['args'])

    # Split the checkpoint into the base model's weights and everything else
    # (optimizer state, adaptors, ...), which is preserved untouched.
    key = 'base_model.'
    mod_state = dict()
    extra_state = dict()
    for k, v in chk['state_dict'].items():
        if k.startswith(key):
            mod_state[k[len(key):]] = v
        else:
            extra_state[k] = v

    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
        print(chk_load_info)

    if chk['args'].spectral_reparam:
        # Bake the parametrized weights into plain tensors.
        disable_spectral_reparam(model)

    if hasattr(chk['args'], 'dtype'):
        model.to(dtype=chk['args'].dtype)

    # Re-assemble the checkpoint with the cleaned weights under the original
    # "base_model." prefix, and mark reparametrization as disabled.
    mod_state = model.state_dict()
    final_state = dict()
    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
    final_state.update(extra_state)

    chk['state_dict'] = final_state
    chk['args'].spectral_reparam = False

    if args.release:
        # Keep only the fields needed to reload the model.
        chk = {
            'arch': chk['arch'],
            'epoch': chk['epoch'],
            'state_dict': chk['state_dict'],
            'args': chk['args'],
        }

    torch.save(chk, args.output)

    pass
================================================
FILE: nit/models/nvidia_radio/radio/eradio_model.py
================================================
#!/usr/bin/env python3
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
# E-RADIO model from
# Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
# based on FasterViT, Swin Transformer, YOLOv8
# FasterViT:
# Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
import timm
import torch
import torch.nn as nn
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
import numpy as np
import torch.nn.functional as F
import math
import warnings
#######################
## Codebase from YOLOv8
## BEGINNING
#######################
class C2f(nn.Module):
    """Faster implementation of a CSP bottleneck with 2 convolutions
    (from the YOLOv8 codebase)."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):
        # c1/c2: in/out channels, n: bottleneck count, g: groups, e: expansion
        super().__init__()
        if drop_path is None:
            drop_path = [0.0] * n
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i])
            for i in range(n)
        )

    def forward(self, x):
        """Forward pass through the C2f layer."""
        parts = list(self.cv1(x).chunk(2, 1))
        for m in self.m:
            parts.append(m(parts[-1]))
        return self.cv2(torch.cat(parts, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        parts = list(self.cv1(x).split((self.c, self.c), 1))
        for m in self.m:
            parts.append(m(parts[-1]))
        return self.cv2(torch.cat(parts, 1))
class Bottleneck(nn.Module):
    """Standard residual bottleneck (YOLOv8 style): two convs plus an optional skip."""
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):
        """
        Args:
            c1: input channels.
            c2: output channels.
            shortcut: add the residual connection (only effective when c1 == c2).
            g: groups of the second conv.
            k: kernel sizes of the two convs.
            e: hidden-channel expansion ratio.
            drop_path: stochastic-depth rate applied to the residual branch.
        """
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, k[0], 1)
        self.cv2 = Conv(hidden, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
        self.drop_path1 = nn.Identity() if drop_path <= 0. else DropPath(drop_path)

    def forward(self, x):
        """Apply the two convs; add the (drop-path'd) residual when enabled."""
        out = self.cv2(self.cv1(x))
        return x + self.drop_path1(out) if self.add else out
class Conv(nn.Module):
    """Conv2d + BatchNorm + activation, with BN foldable into the conv for deployment."""
    default_act = nn.SiLU()  # default activation

    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
        super().__init__()
        # Conv bias is always disabled: BatchNorm's affine terms take its place.
        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
        # act=True -> default SiLU; an nn.Module -> use it; anything else -> identity
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """conv -> batchnorm -> activation."""
        return self.act(self.bn(self.conv(x)))

    @torch.no_grad()
    def switch_to_deploy(self):
        """Fold BatchNorm statistics into the conv weights (inference-time fusion)."""
        if isinstance(self.bn, nn.Identity):
            return  # already fused
        conv, bn = self.conv, self.bn
        std = (bn.running_var + bn.eps) ** 0.5
        scale = bn.weight / std
        conv.weight.data.copy_(conv.weight * scale[:, None, None, None])
        self.conv.bias = nn.Parameter(bn.bias - bn.running_mean * scale)
        self.bn = nn.Identity()
def autopad(k, p=None, d=1):
    """Compute 'same'-shape padding for a kernel size and dilation.

    Args:
        k: kernel size, an int or a sequence of ints.
        p: explicit padding; returned unchanged when given.
        d: dilation factor.
    Returns:
        Padding (int or list) that preserves the spatial output size.
    """
    if d > 1:
        # effective kernel size once dilation is accounted for
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (v - 1) + 1 for v in k]
    if p is not None:
        return p
    # half the (effective) kernel size gives 'same' padding
    return k // 2 if isinstance(k, int) else [v // 2 for v in k]
#######################
## Codebase from YOLOv8
## END
#######################
def pixel_unshuffle(data, factor=2):
    """Manual reverse of nn.PixelShuffle(factor).

    Written out by hand because the built-in torch op has export issues with
    ONNX and TensorRT. Spatial resolution shrinks by `factor`, channels grow
    by `factor**2`.
    """
    B, C, H, W = data.shape
    out_h, out_w = H // factor, W // factor
    blocks = data.view(B, C, factor, out_h, factor, out_w)
    blocks = blocks.permute(0, 1, 2, 4, 3, 5)
    return blocks.reshape(B, -1, out_h, out_w)
class SwiGLU(nn.Module):
    """SwiGLU activation: the second half of the last dim gates the first half.

    Could be more elaborate, but a fancier variant did not improve results so far.
    """
    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)
def window_partition(x, window_size):
    """Split a feature map into non-overlapping windows for windowed attention.

    Args:
        x: input of shape (B, C, H, W).
        window_size: side length of each square window; 0, or a window
            covering the whole map, disables windowing and just flattens
            the spatial dims.
    Returns:
        windows: (num_windows*B, window_size*window_size, C) local tokens.
        (Hp, Wp): spatial size after padding.
    """
    B, C, H, W = x.shape
    if window_size == 0 or (window_size == H and window_size == W):
        # single "window" over the whole map: no padding needed
        return x.flatten(2).transpose(1, 2), (H, W)
    # reflect-pad so both spatial dims are multiples of the window size
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
    Hp, Wp = H + pad_h, W + pad_w
    grid = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
    windows = grid.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size * window_size, C)
    return windows, (Hp, Wp)
class Conv2d_BN(nn.Module):
    """Conv2d followed by BatchNorm2d, with the BN foldable into the conv
    weights to speed up inference.

    Could be merged with Conv() given a couple of extra arguments.
    """
    def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
        super().__init__()
        # conv bias is always off: BN's affine terms take its place
        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        """conv -> batchnorm."""
        return self.bn(self.conv(x))

    @torch.no_grad()
    def switch_to_deploy(self):
        """Fold the BN statistics into the conv, then replace the BN with identity."""
        if isinstance(self.bn, nn.Identity):
            return  # already fused
        conv, bn = self.conv, self.bn
        std = (bn.running_var + bn.eps) ** 0.5
        scale = bn.weight / std
        conv.weight.data.copy_(conv.weight * scale[:, None, None, None])
        self.conv.bias = nn.Parameter(bn.bias - bn.running_mean * scale)
        self.bn = nn.Identity()
def window_reverse(windows, window_size, H, W, pad_hw):
    """Reassemble windowed tokens back into a full feature map.

    Inverse of window_partition(): undoes the windowing and strips any
    padding that was added to make H/W divisible by the window size.

    Args:
        windows: local window features (num_windows*B, window_size*window_size, C),
            or (B, H*W, C) when windowing was disabled.
        window_size: window side length; 0 (or a window covering the whole map)
            means the tokens are a single flattened map per image.
        H: target (unpadded) output height.
        W: target (unpadded) output width.
        pad_hw: (Hp, Wp) padded size returned by window_partition().
    Returns:
        x: feature map of shape (B, C, H, W).
    """
    Hp, Wp = pad_hw
    if window_size == 0 or (window_size == H and window_size == W):
        # One window per image: tokens are just the flattened spatial dims,
        # so the batch size falls out of the view. (The previous code divided
        # by window_size here, raising ZeroDivisionError for window_size == 0,
        # a value window_partition explicitly supports.)
        x = windows.transpose(1, 2).view(-1, windows.shape[2], H, W)
    else:
        B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], Hp, Wp)
        if Hp > H or Wp > W:
            # strip the padding that window_partition() added
            x = x[:, :, :H, :W, ].contiguous()
    return x
class PosEmbMLPSwinv2D(nn.Module):
    """
    2D relative positional embedding from Swin Transformer v2.

    A small MLP (the "continuous position bias" MLP) maps log-spaced relative
    window coordinates to a per-head bias that is added to the attention
    logits. Added functionality vs. the original: the computed bias is stored
    in the model (buffer ``relative_bias``) so it is not recomputed on every
    forward pass once in deploy mode.
    """
    def __init__(
        self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,
    ):
        """
        Args:
            window_size: [h, w] of the attention window.
            pretrained_window_size: [h, w] window size used during pretraining
                (coordinates are rescaled relative to it); a non-positive first
                entry rescales by the current window size instead.
            num_heads: number of attention heads (output channels of the bias).
            seq_length: tokens per window, used to size the cached bias buffer.
            no_log: if True, skip the log-space coordinate normalization.
            cpb_mlp_hidden: hidden width of the continuous-position-bias MLP.
        """
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, cpb_mlp_hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(cpb_mlp_hidden, num_heads, bias=False),
        )
        self.grid_exists = False  # True once the cached bias matches the current grid
        self.seq_length = seq_length
        self.deploy = False  # in deploy mode forward() reuses the cached bias
        self.num_heads = num_heads
        self.no_log = no_log
        self.pretrained_window_size = pretrained_window_size
        self.relative_bias_window_size = window_size  # window size the buffers below were built for

        relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,
                                                                                                          pretrained_window_size, seq_length,
                                                                                                          no_log)

        self.register_buffer("relative_coords_table", relative_coords_table)
        self.register_buffer("relative_position_index", relative_position_index)
        self.register_buffer("relative_bias", relative_bias)  # for EMA

    def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
        """Build the relative-coordinate table, the pairwise relative-position
        index, and a zero-initialized bias cache for the given window size.

        Kept as a separate function so the buffers can be rebuilt when the
        window size changes after model weights have been loaded.
        """
        relative_coords_h = torch.arange(
            -(window_size[0] - 1), window_size[0], dtype=torch.float32
        )
        relative_coords_w = torch.arange(
            -(window_size[1] - 1), window_size[1], dtype=torch.float32
        )
        relative_coords_table = (
            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
            .permute(1, 2, 0)
            .contiguous()
            .unsqueeze(0)
        )  # 1, 2*Wh-1, 2*Ww-1, 2
        # normalize coordinates relative to the (pretrained) window size
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
        else:
            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        if not no_log:
            relative_coords_table *= 8  # normalize to -8, 8
            # sign-preserving log2 spacing, as in Swin v2
            relative_coords_table = (
                torch.sign(relative_coords_table)
                * torch.log2(torch.abs(relative_coords_table) + 1.0)
                / np.log2(8)
            )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        # cache for the computed bias; filled on the first (training) forward
        relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)

        # remember which window size these tensors correspond to
        self.relative_bias_window_size = window_size

        return relative_coords_table, relative_position_index, relative_bias

    def switch_to_deploy(self):
        """Freeze the current cached bias so forward() becomes a single add."""
        self.deploy = True
        self.grid_exists = True

    def forward(self, input_tensor):
        # for efficiency, we want this forward to be folded into a single operation (sum)
        # if resolution stays the same, then we dont need to recompute MLP layers
        if not self.deploy or self.training:
            self.grid_exists = False

            # compare if all elements in self.window_size list match those in self.relative_bias_window_size
            if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):
                # window size changed since the buffers were built -> rebuild them on the same device
                relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,
                                                                                                                  self.pretrained_window_size, self.seq_length,
                                                                                                                  self.no_log)

                self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)
                self.relative_position_index = relative_position_index.to(self.relative_position_index.device)
                self.relative_bias = relative_bias.to(self.relative_bias.device)

        # deploy fast path: add the cached bias without running the MLP
        if self.deploy and self.grid_exists:
            input_tensor = input_tensor + self.relative_bias
            return input_tensor

        if 1:
            self.grid_exists = True

            # run the MLP over the coordinate table, then gather the bias for
            # every (query, key) token pair via the precomputed index
            relative_position_bias_table = self.cpb_mlp(
                self.relative_coords_table
            ).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1,
            )  # Wh*Ww,Wh*Ww,nH

            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww, Wh*Ww
            # bound the bias to (0, 16) as in Swin v2
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

            self.relative_bias = relative_position_bias.unsqueeze(0)

        input_tensor = input_tensor + self.relative_bias
        return input_tensor
class GRAAttentionBlock(nn.Module):
    def __init__(self, window_size, dim_in, dim_out,
                 num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
                 norm_layer=nn.LayerNorm, layer_scale=None,
                 use_swiglu=True,
                 subsample_ratio=1, dim_ratio=1, conv_base=False,
                 do_windowing=True, multi_query=False, use_shift=0,
                 cpb_mlp_hidden=512, conv_groups_ratio=0):
        '''
        Global Resolution Attention Block, see README for details
        Attention with subsampling to get a bigger receptive field for attention

        Args:
            window_size: window side length for windowed attention.
            dim_in / dim_out: input / output channel dimensions.
            num_heads: attention heads.
            drop_path: stochastic-depth rate for both residual branches.
            layer_scale: numeric coefficient enables learnable per-channel scaling.
            use_swiglu: use a SwiGLU MLP instead of a GELU MLP.
            subsample_ratio: spatial downsample factor applied before attention
                (and undone after it) to enlarge the receptive field.
            conv_base: use conv2d instead of avgpool2d for downsample / upsample.
            do_windowing: enable the downsample -> window -> attend -> upsample path.
            multi_query: share one K/V head across query heads in the attention.
            use_shift: Swin-like cyclic window shift (half the window size).
            conv_groups_ratio: group ratio of the extra 3x3 conv used when
                subsample_ratio == 1; -1 disables that conv entirely.
        '''
        super().__init__()

        # half-window cyclic shift, Swin style
        self.shift_size=window_size//2 if use_shift else 0
        self.do_windowing = do_windowing
        self.subsample_ratio = subsample_ratio

        if do_windowing:
            if conv_base:
                # learned (transposed-)conv down/up-sampling
                self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
                self.downsample_mixer = nn.Identity()
                self.upsample_mixer = nn.Identity()
                self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
            else:
                # pooling / nearest-upsample with 1x1 conv mixers
                self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
                self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
                self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
                self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()

        # in case there is no downsampling conv we want to have it separately
        # will help with information propagation between windows
        if subsample_ratio == 1:
            self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
            # for simplicity, no activation after the pre-conv:
            self.pre_conv_act = nn.Identity()
            # conv_groups_ratio == -1 disables the extra conv entirely
            if conv_groups_ratio == -1:
                self.pre_conv = nn.Identity()
                self.pre_conv_act = nn.Identity()

        self.window_size = window_size

        self.norm1 = norm_layer(dim_in)

        self.attn = WindowAttention(
            dim_in,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            resolution=window_size,
            seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
            shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # layer scale only when a numeric coefficient is supplied
        use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
        self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1

        ### mlp layer
        mlp_ratio = 4
        self.norm2 = norm_layer(dim_in)
        mlp_hidden_dim = int(dim_in * mlp_ratio)

        activation = nn.GELU if not use_swiglu else SwiGLU
        # SwiGLU halves the effective width, rounded down to a multiple of 64
        mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim

        self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)

        self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
        self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Downsample (optional) -> windowed attention + MLP -> upsample -> 0.5/0.5 skip."""
        skip_connection = x
        attn_mask = None

        # in case there is no downsampling conv we want to have it separately
        # will help with information propagation
        if self.subsample_ratio == 1:
            x = self.pre_conv_act(self.pre_conv(x)) + skip_connection

        if self.do_windowing:
            # performing windowing if required
            x = self.downsample_op(x)
            x = self.downsample_mixer(x)

            if self.window_size>0:
                H, W = x.shape[2], x.shape[3]

                if self.shift_size > 0 and H>self.window_size and W>self.window_size:
                    # @swin like cyclic shift, doesnt show better performance
                    x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))

                x, pad_hw = window_partition(x, self.window_size)

                if self.shift_size > 0 and H>self.window_size and W>self.window_size:
                    # set atten matrix to have -100 and the top right square
                    # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0
                    # calculate attention mask for SW-MSA
                    # not used in final version, can be useful for some cases especially for high res
                    H, W = pad_hw
                    img_mask = torch.zeros((1, H, W, 1), device=x.device)  # 1 H W 1
                    h_slices = (slice(0, -self.window_size),
                                slice(-self.window_size, -self.shift_size),
                                slice(-self.shift_size, None))
                    w_slices = (slice(0, -self.window_size),
                                slice(-self.window_size, -self.shift_size),
                                slice(-self.shift_size, None))
                    cnt = 0
                    for h in h_slices:
                        for w in w_slices:
                            img_mask[:, h, w, :] = cnt
                            cnt += 1
                    img_mask = img_mask.transpose(1,2).transpose(1,3)
                    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
                    mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size)
                    # tokens from different mask regions must not attend to each other
                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        # window attention
        x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask))  # or pass H,W
        # mlp layer
        x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))

        if self.do_windowing:
            if self.window_size > 0:
                x = window_reverse(x, self.window_size, H, W, pad_hw)

            # reverse cyclic shift
            if self.shift_size > 0 and H>self.window_size and W>self.window_size:
                # @swin like cyclic shift, not tested
                x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))

            x = self.upsample_mixer(x)
            x = self.upsample_op(x)

            # up/down-sampling may disagree with the input size by a pixel; pad to match
            if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
                x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode="reflect")
        # need to add skip connection because downsampling and upsampling will break residual connection
        # 0.5 is needed to make sure that the skip connection is not too strong
        # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
        x = 0.5 * x + 0.5 * skip_connection
        return x
class MultiResolutionAttention(nn.Module):
    """
    MultiResolutionAttention (MRA) module.

    Stacks several GRA attention blocks, one per entry of `sr_ratio`, each
    operating at its own subsampling ratio (and, optionally, its own window
    size), so attention sees the feature map at multiple resolutions. Every
    block supports windowing.
    """
    def __init__(self, window_size, sr_ratio,
                 dim, dim_ratio, num_heads,
                 do_windowing=True,
                 layer_scale=1e-5, norm_layer=nn.LayerNorm,
                 drop_path = 0, qkv_bias=False, qk_scale=1.0,
                 use_swiglu=True, multi_query=False, conv_base=False,
                 use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None:
        """
        Args:
            window_size: per-block window sizes; when shorter than sr_ratio,
                the first entry is reused for the remaining blocks.
            sr_ratio: list of subsampling (compression) ratios — its length
                determines how many GRA blocks are stacked.
            dim: feature dimension.
            dim_ratio: channel expansion ratio forwarded to each GRA block.
            num_heads: attention heads per block.
            do_windowing: enable windowed attention inside each block.
            use_shift: do window shifting.
        """
        super().__init__()

        self.attention_blocks = nn.ModuleList()
        for idx, subsample_ratio in enumerate(sr_ratio):
            # fall back to the first window size if the list is shorter
            window_size_local = window_size[idx] if idx < len(window_size) else window_size[0]
            self.attention_blocks.append(
                GRAAttentionBlock(
                    window_size=window_size_local,
                    dim_in=dim, dim_out=dim, num_heads=num_heads,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
                    layer_scale=layer_scale, drop_path=drop_path,
                    use_swiglu=use_swiglu, subsample_ratio=subsample_ratio,
                    dim_ratio=dim_ratio, do_windowing=do_windowing,
                    multi_query=multi_query, conv_base=conv_base,
                    use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden,
                    conv_groups_ratio=conv_groups_ratio,
                )
            )

    def forward(self, x):
        """Apply every resolution's attention block in sequence."""
        for attention_block in self.attention_blocks:
            x = attention_block(x)
        return x
class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block: fc1 -> activation -> fc2.

    When `use_swiglu` is set, fc1 produces twice the hidden width so a gated
    activation (expected to be SwiGLU) can gate one half with the other.
    """
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 use_swiglu=True,
                 drop=0.):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension (defaults to in_features).
            out_features: output features dimension (defaults to in_features).
            act_layer: activation function class.
            use_swiglu: double fc1's output width for a gated activation.
            drop: dropout rate (accepted for interface compatibility; unused).
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        fc1_width = hidden_features * 2 if use_swiglu else hidden_features
        self.fc1 = nn.Linear(in_features, fc1_width, bias=False)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)

    def forward(self, x):
        """Flatten all leading dims, run the MLP, then restore the input shape."""
        shape = x.size()
        out = self.fc2(self.act(self.fc1(x.view(-1, shape[-1]))))
        return out.view(shape)
class Downsample(nn.Module):
    """
    Down-sampling block: halves the spatial size and doubles the channels.

    Either pixel-unshuffle + 1x1 conv (works great accuracy-wise but ~10%
    slower in TensorRT) or a single strided 3x3 conv.
    """
    def __init__(self,
                 dim,
                 shuffle = False,
                 ):
        """
        Args:
            dim: input feature size dimension (output has 2*dim channels).
            shuffle: use PixelUnshuffle-based reduction instead of a strided conv.
        """
        super().__init__()
        dim_out = 2 * dim
        if shuffle:
            # pixel unshuffling works well but doesn't provide any speedup
            self.norm = lambda x: pixel_unshuffle(x, factor=2)
            self.reduction = Conv2d_BN(dim * 4, dim_out, 1, 1, 0, bias=False)
        else:
            # LayerNorm was removed compared to the original FasterViT: it is
            # costly for high-resolution inputs (pools over the whole spatial
            # extent) and dropping it gives ~10% better speed
            self.norm = nn.Identity()
            self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)

    def forward(self, x):
        """Normalize (or unshuffle) then reduce spatial size / expand channels."""
        return self.reduction(self.norm(x))
class PatchEmbed(nn.Module):
    """
    Patch embedding block.

    Converts the input image into initial feature maps at 1/4 resolution,
    either with two strided conv stages or with PixelUnshuffle followed by a
    single conv stage.
    """
    def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
        """
        Args:
            in_chans: number of input channels.
            in_dim: intermediate feature size dimension to speed up stem.
            dim: final stem channel number.
            shuffle_down: use PixelUnshuffle for down-sampling, effectively
                increasing the receptive field.
        """
        super().__init__()
        if shuffle_down:
            # 4x spatial reduction via unshuffle (channels grow 16x), then one conv stage
            self.proj = lambda x: pixel_unshuffle(x, factor=4)
            self.conv_down = nn.Sequential(
                Conv2d_BN(in_chans * 16, dim, 3, 1, 1),
                nn.ReLU(),
            )
        else:
            self.proj = nn.Identity()
            self.conv_down = nn.Sequential(
                Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
                nn.ReLU(),
                Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
                nn.ReLU(),
            )

    def forward(self, x):
        """Project (optionally unshuffle) then run the conv stem."""
        return self.conv_down(self.proj(x))
class ConvBlock(nn.Module):
    """
    Convolutional block used in the first couple of stages.

    ResNet-18-like: conv-BN -> GELU -> conv-BN with a (drop-path'd) residual
    connection and optional per-channel layer scaling. Plain resnet-18-style
    modules proved best in terms of throughput here.
    """
    def __init__(self, dim,
                 drop_path=0.,
                 layer_scale=None,
                 kernel_size=3,
                 ):
        super().__init__()
        self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        self.act1 = nn.GELU()
        self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        # layer scale is only active when a numeric coefficient is supplied
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
            self.layer_scale = True
        else:
            self.layer_scale = False
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)

    def forward(self, x):
        """Residual block: x + drop_path(scale * conv2(act(conv1(x))))."""
        residual = x
        out = self.conv2(self.act1(self.conv1(x)))
        if self.layer_scale:
            out = out * self.gamma.view(1, -1, 1, 1)
        return residual + self.drop_path(out)
class WindowAttention(nn.Module):
    # Windowed Attention from SwinV2
    # use a MLP trick to deal with various input image resolutions, then fold it to improve speed
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
                 seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512):
        """Windowed multi-head self-attention with an MLP-generated relative
        positional bias (SwinV2 style).

        Args:
            dim: input channel dimension.
            num_heads: number of attention heads.
            qkv_bias: add a learnable bias to the qkv projection.
            qk_scale: override for the default head_dim**-0.5 scaling.
            resolution: window side length (feeds the positional-bias MLP).
            seq_length: tokens per window (resolution**2).
            dim_out: output channels (defaults to dim).
            multi_query: share a single key/value head across all query heads.
            shift_size: cyclic-shift size (stored; used by the caller's masking).
            cpb_mlp_hidden: hidden width of the positional-bias MLP.
        """
        # taken from EdgeViT and tweaked with attention bias.
        super().__init__()
        if not dim_out: dim_out = dim
        self.shift_size = shift_size
        self.multi_query = multi_query
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = dim // num_heads

        self.dim_internal = dim

        self.scale = qk_scale or head_dim ** -0.5
        if not multi_query:
            # full Q, K, V projections
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            # multi-query attention: full-width Q, a single shared K and V head
            self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim_out, bias=False)
        # attention positional bias
        self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
                                              pretrained_window_size=[resolution, resolution],
                                              num_heads=num_heads,
                                              seq_length=seq_length,
                                              cpb_mlp_hidden=cpb_mlp_hidden)

        self.resolution = resolution

    def forward(self, x, attn_mask = None):
        """Attend within each window.

        Args:
            x: (B*num_windows, N, C) window tokens.
            attn_mask: optional (nW, N, N) additive mask for shifted windows.
        Returns:
            (B*num_windows, N, C_out) attended tokens.
        """
        B, N, C = x.shape

        if not self.multi_query:
            # (3, B, heads, N, head_dim)
            qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            qkv = self.qkv(x)
            (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)

            # single K/V head broadcasts against all query heads
            q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
            v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # add learned relative positional bias to the attention logits
        attn = self.pos_emb_funct(attn)

        # add window shift mask: per-window additive bias, broadcast over batch
        if attn_mask is not None:
            nW = attn_mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
        x = self.proj(x)
        return x
class ERADIOLayer(nn.Module):
    """
    E-RADIO Layer: one resolution stage of the network.

    Depending on `conv`, the stage is either a stack of convolutional blocks
    (plain ConvBlocks or a YOLOv8 C2f) or a stack of MultiResolutionAttention
    transformer blocks, followed by an optional downsample. forward() returns
    both the downsampled output and the pre-downsample features.
    """
    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 norm_layer=nn.LayerNorm,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 sr_dim_ratio=1,
                 sr_ratio=1,
                 multi_query=False,
                 use_swiglu=True,
                 yolo_arch=False,
                 downsample_shuffle=False,
                 conv_base=False,
                 use_shift=False,
                 cpb_mlp_hidden=512,
                 conv_groups_ratio=0,
                 verbose: bool = True,
                 ):
        """
        Args:
            dim: feature size dimension.
            depth: number of layers in each stage.
            num_heads: number of heads in each stage.
            window_size: window size in each stage.
            conv: build a convolutional stage instead of a transformer stage.
            downsample: bool argument for down-sampling.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop_path: drop path rate.
            norm_layer: normalization layer.
            layer_scale / layer_scale_conv: layer scaling coefficients.
            sr_ratio: per-block subsampling ratios for multi-res attention.
            yolo_arch: use the YOLOv8 C2f block for conv stages.
            use_shift: SWIN like window shifting for half the window size for
                every alternating layer (considering multi-resolution)
            conv_groups_ratio: group ratio for conv when no subsampling in
                multi-res attention
            verbose: warn when the resolution forces feature interpolation.
        """
        super().__init__()
        self.conv = conv
        self.yolo_arch=False
        self.verbose = verbose
        if conv:
            if not yolo_arch:
                # plain ResNet-18-like conv blocks
                self.blocks = nn.ModuleList([
                    ConvBlock(dim=dim,
                              drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                              layer_scale=layer_scale_conv)
                    for i in range(depth)])
                self.blocks = nn.Sequential(*self.blocks)
            else:
                # YOLOv8 C2f block covers the whole depth at once
                self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
                self.yolo_arch=True
        else:
            if not isinstance(window_size, list): window_size = [window_size]
            self.window_size = window_size[0]
            self.do_single_windowing = True
            if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
            self.sr_ratio = sr_ratio
            # any subsampling or multiple window sizes force per-block windowing
            if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
                self.do_single_windowing = False
                do_windowing = True
            else:
                self.do_single_windowing = True
                do_windowing = False

            # for v2_2: per-block windowing is also required whenever the extra
            # pre-conv is enabled (conv_groups_ratio != -1)
            if conv_groups_ratio != -1:
                self.do_single_windowing = False
                do_windowing = True

            self.blocks = nn.ModuleList()
            for i in range(depth):
                self.blocks.append(
                    MultiResolutionAttention(window_size=window_size,
                                             sr_ratio=sr_ratio,
                                             dim=dim,
                                             dim_ratio = sr_dim_ratio,
                                             num_heads=num_heads,
                                             norm_layer=norm_layer,
                                             drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                             layer_scale=layer_scale,
                                             qkv_bias=qkv_bias,
                                             qk_scale=qk_scale,
                                             use_swiglu=use_swiglu,
                                             do_windowing=do_windowing,
                                             multi_query=multi_query,
                                             conv_base=conv_base,
                                             cpb_mlp_hidden=cpb_mlp_hidden,
                                             use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True ,
                                             conv_groups_ratio=conv_groups_ratio,
                                             ))
            self.blocks = nn.Sequential(*self.blocks)

        self.transformer = not conv
        self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)

    def forward(self, x):
        """Run the stage; returns (downsampled features, pre-downsample features)."""
        B, C, H, W = x.shape

        # do padding for transforemr
        interpolate = True
        if self.transformer and interpolate:
            # Windowed Attention will split feature map into windows with the size of window_size x window_size
            # if the resolution is not divisible by window_size, we need to interpolate the feature map
            # can be done via padding, but doing so after training hurts the model performance.
            # interpolation affects the performance as well, but not as much as padding
            if isinstance(self.window_size, list) or isinstance(self.window_size, tuple):
                current_max_window_size = max(self.window_size)
            else:
                current_max_window_size = self.window_size

            # largest window after accounting for per-block subsampling ratios
            max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio])
            if H % max_window_size != 0 or W % max_window_size != 0:
                new_h = int(np.ceil(H/max_window_size)*max_window_size)
                new_w = int(np.ceil(W/max_window_size)*max_window_size)
                x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
                if self.verbose:
                    warnings.warn(f"Choosen window size is not optimal for given resolution. Interpolation of features maps will be done and it can affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")

        # single windowing: partition once here instead of inside every block
        if self.transformer and self.do_single_windowing:
            H, W = x.shape[2], x.shape[3]
            x, pad_hw = window_partition(x, self.window_size)

        # run main blocks
        x = self.blocks(x)

        if self.transformer and self.do_single_windowing:
            x = window_reverse(x, self.window_size, H, W, pad_hw)

        if self.transformer and interpolate:
            # lets keep original resolution, might be not ideal, but for the upsampling tower we need to keep the expected resolution.
            x = F.interpolate(x, size=(H, W), mode='nearest')

        if self.downsample is None:
            return x, x

        return self.downsample(x), x  # changing to output pre downsampled features
class InterpolateLayer(nn.Module):
    """nn.Module wrapper around F.interpolate so it can live inside nn.Sequential."""
    def __init__(self, size=None, scale_factor=None, mode='nearest'):
        """
        Args:
            size: target output size (mutually exclusive with scale_factor).
            scale_factor: spatial size multiplier.
            mode: interpolation algorithm.
        """
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Resize x according to the stored size/scale_factor/mode."""
        return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
class HiResNeck(nn.Module):
    """
    The block is used to output dense features from all stages
    Otherwise, by default, only the last stage features are returned with E-RADIO
    """
    def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
        '''
        Hi Resolution neck to support output of high res features that are useful for dense tasks.
        depths - total number of layers in the base model
        neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc.
                           earlier layers result in higher resolution features at the cost of compute
        full_features_head_dim - number of channels in the dense features head
        downsample_enabled - per-stage flags; each enabled downsample doubles the
                             upsample ratio needed to bring that stage's features
                             back to the neck's common resolution
        '''
        super().__init__()

        # create feature projection layers for segmentation output
        self.neck_features_proj = nn.ModuleList()
        self.neck_start_stage = neck_start_stage
        upsample_ratio = 1  # how much each stage must be upsampled to match stage `neck_start_stage`
        for i in range(len(depths)):
            level_n_features_output = int(dim * 2 ** i)

            if self.neck_start_stage > i: continue

            if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
                feature_projection = nn.Sequential()
                if False:
                    # alternative: single transposed conv (fast, but worse)
                    feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
                    feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
                                                                              full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
                else:
                    # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio
                    feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest'))
                    # depthwise 3x3 conv to mix spatially after nearest upsampling
                    feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output))
                    feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
                    # B, in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio
                    feature_projection.add_module("conv2", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0))
            else:
                # stage already matches the neck's channels and resolution
                feature_projection = nn.Sequential()
            self.neck_features_proj.append(feature_projection)

            if i>0 and downsample_enabled[i]:
                upsample_ratio *= 2

    def forward(self, x, il_level=-1, full_features=None):
        """Project stage `il_level` features and accumulate into `full_features`.

        Called once per stage; returns `full_features` unchanged for stages
        before `neck_start_stage`.
        """
        if self.neck_start_stage > il_level:
            return full_features

        if full_features is None:
            full_features = self.neck_features_proj[il_level - self.neck_start_stage](x)
        else:
            #upsample torch tensor x to match full_features size, and add to full_features
            feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
            # pad by the (possibly negative) size difference so shapes match exactly
            if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
                feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
            full_features = full_features + feature_projection
        return full_features
class ERADIO(nn.Module):
    """
    Efficient RADIO

    A hierarchical 4-stage backbone: convolutional early stages followed by
    windowed-attention stages, with an optional high-resolution neck that
    aggregates pre-downsample features from each stage.
    """
    def __init__(self,
                 dim,
                 in_dim,
                 depths,
                 window_size,
                 mlp_ratio,
                 num_heads,
                 drop_path_rate=0.2,
                 in_chans=3,
                 num_classes=1000,
                 qkv_bias=False,
                 qk_scale=None,
                 layer_scale=None,
                 layer_scale_conv=None,
                 layer_norm_last=False,
                 sr_ratio = [1, 1, 1, 1],
                 max_depth = -1,
                 conv_base=False,
                 use_swiglu=False,
                 multi_query=False,
                 norm_layer=nn.LayerNorm,
                 drop_uniform=False,
                 yolo_arch=False,
                 shuffle_down=False,
                 downsample_shuffle=False,
                 return_full_features=False,
                 full_features_head_dim=128,
                 neck_start_stage=1,
                 use_neck=False,
                 use_shift=False,
                 cpb_mlp_hidden=512,
                 conv_groups_ratio=0,
                 verbose: bool = False,
                 **kwargs):
        """
        Args:
            dim: feature size dimension.
            depths: number of layers in each stage.
            window_size: window size in each stage.
            mlp_ratio: MLP ratio.
            num_heads: number of heads in each stage.
            drop_path_rate: drop path rate.
            in_chans: number of input channels.
            num_classes: number of classes.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop_rate: dropout rate.
            attn_drop_rate: attention dropout rate.
            norm_layer: normalization layer.
            layer_scale: layer scaling coefficient.
            return_full_features: output dense features as well as logits
            full_features_head_dim: number of channels in the dense features head
            neck_start_stage: a stage id to start full feature neck. Model has 4 stages, indix starts with 0
                              for 224 resolution, the output of the stage before downsample:
                              stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
            use_neck: even for summarization embedding use neck
            use_shift: SWIN like window shifting but without masking attention
            conv_groups_ratio: will be used for conv blocks where there is no multires attention,
                               if 0 then normal conv,
                               if 1 then channels are independent,
                               if -1 then no conv at all
        """
        super().__init__()
        # Final stage channel count: dim doubles at every downsample.
        num_features = int(dim * 2 ** (len(depths) - 1))
        self.num_classes = num_classes
        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
        # set return_full_features true if we want to return full features from all stages
        self.return_full_features = return_full_features
        self.use_neck = use_neck
        # Stochastic depth rates: linear ramp over all blocks, or uniform.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        if drop_uniform:
            dpr = [drop_path_rate for x in range(sum(depths))]
        if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
        self.levels = nn.ModuleList()
        for i in range(len(depths)):
            # Stages 0 and 1 are convolutional; stages 2 and 3 use attention.
            conv = True if (i == 0 or i == 1) else False
            level = ERADIOLayer(dim=int(dim * 2 ** i),
                                depth=depths[i],
                                num_heads=num_heads[i],
                                window_size=window_size[i],
                                mlp_ratio=mlp_ratio,
                                qkv_bias=qkv_bias,
                                qk_scale=qk_scale,
                                conv=conv,
                                # Slice this stage's share of the drop-path ramp.
                                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                                downsample=(i < len(depths) - 1),
                                layer_scale=layer_scale,
                                layer_scale_conv=layer_scale_conv,
                                sr_ratio=sr_ratio[i],
                                use_swiglu=use_swiglu,
                                multi_query=multi_query,
                                norm_layer=norm_layer,
                                yolo_arch=yolo_arch,
                                downsample_shuffle=downsample_shuffle,
                                conv_base=conv_base,
                                cpb_mlp_hidden=cpb_mlp_hidden,
                                use_shift=use_shift,
                                conv_groups_ratio=conv_groups_ratio,
                                verbose=verbose)
            self.levels.append(level)
        if self.return_full_features or self.use_neck:
            #num_heads
            # NOTE(review): `i-1` means entry 0 reads the *last* level's
            # downsample flag (Python negative indexing) — matches upstream
            # E-RADIO, but verify this offset is intentional.
            downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))]
            self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled)
        self.switched_to_deploy = False
        self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)
    def _init_weights(self, m):
        # Standard init: trunc-normal linear weights, unit norm scales, zero biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, LayerNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # Relative position bias tables are excluded from weight decay.
        return {'rpb'}
    def forward_features(self, x):
        """Run the backbone; returns (final stage features, dense neck features or None)."""
        _, _, H, W = x.shape
        if H % 32 != 0 or W % 32 != 0:
            raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}")
        x = self.patch_embed(x)
        full_features = None
        for il, level in enumerate(self.levels):
            x, pre_downsample_x = level(x)
            if self.return_full_features or self.use_neck:
                # Feed each stage's pre-downsample map into the neck accumulator.
                full_features = self.high_res_neck(pre_downsample_x, il, full_features)
        # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
        x = self.norm(x) # new version for
        if not self.return_full_features:
            return x, None
        return x, full_features
    def forward(self, x):
        """Classification forward: pooled logits, plus dense features when enabled."""
        x, full_features = self.forward_features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        if full_features is not None:
            return x, full_features
        return x
    def switch_to_deploy(self):
        '''
        A method to perform model self-compression
        merges BN into conv layers
        converts MLP relative positional bias into precomputed buffers
        '''
        if not self.switched_to_deploy:
            # Recursively ask every submodule that supports it to fuse itself.
            for level in [self.patch_embed, self.levels, self.head]:
                for module in level.modules():
                    if hasattr(module, 'switch_to_deploy'):
                        module.switch_to_deploy()
        self.switched_to_deploy = True
    def change_window_size(self, new_window_size):
        """
        E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
        especially in cases of uneven partitioning of the feature maps.
        E-RADIO allows for the adjustment of the window size after training,
        making it adaptable to different input image resolutions.
        The recommended values for window size based on input resolution are as follows:
        Input Resolution | Window Size
        224 | 7
        256 | 8
        386 | 12
        512 | 16
        Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
        img_res/16/2
        for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
        Manual way to change resolution -> model.change_window_size(resolution)
        """
        window_size = new_window_size
        print(f"Setting window size to {window_size}")
        # Walk every submodule and rewrite its window_size attribute in place,
        # preserving the attribute's container type (tuple/list/scalar).
        for module in self.modules():
            if hasattr(module, "window_size"):
                # check if tuple or a number
                if isinstance(module.window_size, tuple):
                    if module.window_size[0] != window_size:
                        module.window_size = (window_size, window_size)
                elif isinstance(module.window_size, list):
                    if module.window_size[0] != window_size:
                        module.window_size = [window_size, window_size]
                else:
                    module.window_size = window_size
    def set_optimal_window_size(self, image_dim, max_window_size = 16):
        """
        Using hand picked window size for various resolutions.
        E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
        especially in cases of uneven partitioning of the feature maps.
        E-RADIO allows for the adjustment of the window size after training,
        making it adaptable to different input image resolutions.
        The recommended values for window size based on input resolution are as follows:
        Input Resolution | Window Size
        224 | 7
        256 | 8
        386 | 12
        512 | 16
        Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
        img_res/16/2
        for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
        Manual way to change resolution -> model.change_window_size(resolution)
        """
        # import math
        def divisorGenerator(n):
            # Yields all divisors of n in ascending order.
            large_divisors = []
            for i in range(1, int(math.sqrt(n) + 1)):
                if n % i == 0:
                    yield i
                    if i*i != n:
                        large_divisors.append(n / i)
            for divisor in reversed(large_divisors):
                yield divisor
        if isinstance(image_dim, list) or isinstance(image_dim, tuple):
            image_dim = min(image_dim)
        # we do windowed attention in the 3rd stage for the first time, therefore //16,
        # we do subsampled attention with downsample by 2 so need to get //32 actually
        # ideally we should rewrite this to be dependent on the structure of the model like what if subsampled is removed etc
        all_divisors = np.array(list(divisorGenerator(image_dim//32)))
        # Largest divisor of image_dim//32 that does not exceed max_window_size.
        new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
        # for image_dim in [128, 224, 256, 384, 512, 768, 1024]:
        #     all_divisors = np.array(list(divisorGenerator(image_dim//32)))
        #     new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
        #     print(f"Setting window size to {new_window_size} for image resolution {image_dim}")
        self.change_window_size(new_window_size = new_window_size)
@register_model
def eradio_large_fullres_ws16(pretrained=False, **kwargs):
    """E-RADIO Large with a full-resolution neck (1536-dim dense head, ws=16).

    Args:
        pretrained: falsy for random init; otherwise a checkpoint path whose
            ``"state_dict"`` entry is loaded into the model.
        **kwargs: forwarded verbatim to the ERADIO constructor.
    """
    base_cfg = dict(
        depths=[3, 3, 5, 5],
        num_heads=[2, 4, 8, 16],
        window_size=[None, None, [16, 16], 16],
        dim=192,
        in_dim=64,
        mlp_ratio=4,
        drop_path_rate=0.0,
        sr_ratio=[1, 1, [2, 1], 1],
        use_swiglu=False,
        yolo_arch=True,
        shuffle_down=False,
        conv_base=True,
        use_neck=True,
        full_features_head_dim=1536,
        neck_start_stage=2,
    )
    # Double-star expansion keeps the original duplicate-kwarg TypeError behavior.
    model = ERADIO(**base_cfg, **kwargs)
    if pretrained:
        checkpoint = torch.load(pretrained)
        model.load_state_dict(checkpoint["state_dict"])
    return model
@register_model
def eradio_xxxtiny(pretrained=False, **kwargs):  # ,
    """E-RADIO XXX-Tiny (dim=32) with a 256-dim dense-feature neck.

    NOTE(review): unlike the other factories, this one loads the checkpoint
    directly (no ``"state_dict"`` key) — presumably a different checkpoint
    format; verify against the published weights.
    """
    base_cfg = dict(
        depths=[1, 3, 4, 5],
        num_heads=[2, 4, 8, 16],
        window_size=[None, None, [16, 16], 16],
        dim=32,
        in_dim=32,
        mlp_ratio=4,
        drop_path_rate=0.0,
        sr_ratio=[1, 1, [2, 1], 1],
        use_swiglu=False,
        yolo_arch=True,
        shuffle_down=False,
        conv_base=True,
        use_neck=True,
        full_features_head_dim=256,
        neck_start_stage=2,
    )
    model = ERADIO(**base_cfg, **kwargs)
    if pretrained:
        model.load_state_dict(torch.load(pretrained))
    return model
@register_model
def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
    """E-RADIO XXX-Tiny, window size 12, depthwise convs (conv_groups_ratio=1)."""
    base_cfg = dict(
        depths=[1, 3, 4, 5],
        num_heads=[2, 4, 8, 16],
        window_size=[None, None, [12, 12], 12],
        dim=32,
        in_dim=32,
        mlp_ratio=4,
        drop_path_rate=0.0,
        sr_ratio=[1, 1, [2, 1], 1],
        use_swiglu=False,
        downsample_shuffle=False,
        yolo_arch=True,
        shuffle_down=False,
        cpb_mlp_hidden=64,
        use_neck=True,
        full_features_head_dim=256,
        neck_start_stage=2,
        conv_groups_ratio=1,
    )
    model = ERADIO(**base_cfg, **kwargs)
    if pretrained:
        checkpoint = torch.load(pretrained)
        model.load_state_dict(checkpoint["state_dict"])
    return model
@register_model
def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
    """E-RADIO XXX-Tiny, window size 16, neck starting at stage 1."""
    base_cfg = dict(
        depths=[1, 3, 4, 5],
        num_heads=[2, 4, 8, 16],
        window_size=[None, None, [16, 16], 16],
        dim=32,
        in_dim=32,
        mlp_ratio=4,
        drop_path_rate=0.0,
        sr_ratio=[1, 1, [2, 1], 1],
        use_swiglu=False,
        downsample_shuffle=False,
        yolo_arch=True,
        shuffle_down=False,
        cpb_mlp_hidden=64,
        use_neck=True,
        full_features_head_dim=256,
        neck_start_stage=1,
        conv_groups_ratio=1,
    )
    model = ERADIO(**base_cfg, **kwargs)
    if pretrained:
        checkpoint = torch.load(pretrained)
        model.load_state_dict(checkpoint["state_dict"])
    return model
@register_model
def eradio(pretrained=False, **kwargs):
    # Default E-RADIO alias: the large full-resolution ws16 variant.
    return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)
================================================
FILE: nit/models/nvidia_radio/radio/extra_models.py
================================================
from distutils.version import LooseVersion
from types import MethodType
from typing import List, Optional, Tuple, Union
import warnings
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.registry import register_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .forward_intermediates import forward_intermediates
from .input_conditioner import InputConditioner
_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')
class PaliGemmaWrapper(nn.Module):
    """Adapts a PaliGemma vision tower to RADIO's (summary, features) interface.

    Bug fix: the original both assigned ``self.embed_dim`` in ``__init__`` and
    declared a setter-less ``embed_dim`` property, so instantiation raised
    ``AttributeError: can't set attribute``. The value is now stored privately;
    the property prefers the constructor argument and falls back to the wrapped
    model's embedding dim when the argument is None.
    """
    def __init__(self, vis_model: nn.Module, embed_dim: int):
        """
        Args:
            vis_model: the PaliGemma vision model (SigLIP tower).
            embed_dim: output embedding dim to report; None defers to the model.
        """
        super().__init__()
        self.vis_model = vis_model
        self._embed_dim = embed_dim

    @property
    def patch_size(self):
        return self.vis_model.embeddings.patch_size

    @property
    def blocks(self):
        return self.vis_model.encoder.layers

    @property
    def embed_dim(self):
        # Prefer the explicitly requested dim; otherwise read it off the model.
        if self._embed_dim is not None:
            return self._embed_dim
        return self.vis_model.embeddings.embed_dim

    def forward(self, x: torch.Tensor):
        """Returns (summary, features): mean-pooled summary plus patch tokens (fp32)."""
        outputs = self.vis_model(
            x,
            return_dict=False,
            interpolate_pos_encoding=True,
        )
        features = outputs[0].to(torch.float32)
        summary = features.mean(dim=1)
        return summary, features

    def forward_features(self, x: torch.Tensor):
        return self(x)
def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
    """Load a PaliGemma checkpoint from `repo` and wrap its vision tower."""
    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version

    if LooseVersion(tx_version) > LooseVersion('4.44.2'):
        warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')

    extra_args = dict()
    if dtype is not None:
        # HF hosts one checkpoint revision per dtype, named after the dtype suffix.
        extra_args['torch_dtype'] = dtype
        extra_args['revision'] = str(dtype).split('.')[-1]

    full_model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
    return PaliGemmaWrapper(full_model.vision_tower.vision_model, embed_dim)
@register_model
def paligemma_896_student(**kwargs):
    # 896px PaliGemma-3B vision tower reporting embed_dim=1152.
    # dtype=None skips the torch_dtype/revision overrides in the loader.
    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
    return model
def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
x = F.scaled_dot_product_attention(
q, k, v,
is_causal=False,
dropout_p=self.attn_drop.p if self.training else 0.,
scale=self.scale,
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
    """Fetch a DINOv2 model via torch.hub, patching its attention to use SDPA."""
    if cache_dir:
        torch.hub.set_dir(cache_dir)
    model: nn.Module = torch.hub.load(
        'facebookresearch/dinov2',
        dino_v2_model,
        pretrained=pretrained,
    )
    if _has_torch_sdpa:
        # Rebind every `.attn` module's forward to the fused-SDPA variant.
        for name, module in model.named_modules():
            if name.endswith('.attn'):
                module.forward = MethodType(dv2_sdpa, module)
    return model
class DinoWrapper(nn.Module):
    """Presents a DINOv2 backbone through RADIO's (summary, features) API."""

    def __init__(self, dino_model: nn.Module):
        super().__init__()
        self.inner = dino_model
        # Re-register the block list as nn.Sequential so forward_features can
        # run the whole stack as one callable.
        dino_model.blocks = nn.Sequential(*dino_model.blocks)

    @property
    def embed_dim(self):
        return self.inner.embed_dim

    @property
    def patch_size(self):
        return self.inner.patch_size

    @property
    def num_cls_tokens(self):
        # DINOv2 exposes `num_tokens`; default to a single CLS token.
        return getattr(self.inner, 'num_tokens', 1)

    @property
    def num_registers(self):
        # Register tokens only exist on the `_reg` variants.
        return getattr(self.inner, 'num_register_tokens', 0)

    @property
    def num_summary_tokens(self):
        return self.num_cls_tokens + self.num_registers

    @property
    def blocks(self):
        return self.inner.blocks

    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns (cls_token, patch_tokens) from the wrapped model."""
        parts = self.inner.forward_features(*args, **kwargs)
        return parts['x_norm_clstoken'], parts['x_norm_patchtokens']

    def forward_features(self, x: torch.Tensor):
        """Manual pipeline: patchify -> blocks -> norm, then split summary/patches."""
        tokens = self.inner.prepare_tokens_with_masks(x)
        tokens = self.inner.blocks(tokens)
        normed = self.inner.norm(tokens)
        return normed[:, 0], normed[:, self.num_summary_tokens:]

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner.prepare_tokens_with_masks(x)

    def forward_intermediates(self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        return forward_intermediates(
            self,
            patch_extractor=self.inner.prepare_tokens_with_masks,
            num_summary_tokens=self.num_summary_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=self.inner.norm if norm else lambda y: y,
            x=x,
            **kwargs,
        )
def _dino_student(arch: str, **kwargs):
    """Build an untrained DINOv2 architecture wrapped for use as a student."""
    from . import dinov2_arch

    # Look the architecture constructor up by name on the module.
    builder = getattr(dinov2_arch, arch)
    model = DinoWrapper(builder())
    # Standard ImageNet normalization, no input rescaling.
    model.input_conditioner = InputConditioner(
        input_scale=1.0,
        norm_mean=IMAGENET_DEFAULT_MEAN,
        norm_std=IMAGENET_DEFAULT_STD,
    )
    return model
@register_model
def dino_v2_l_student(**kwargs):
    # Untrained DINOv2 ViT-L/14 (with register tokens) student backbone.
    return _dino_student('dinov2_vitl14_reg', **kwargs)
@register_model
def dino_v2_g_student(**kwargs):
    # Untrained DINOv2 ViT-g/14 (with register tokens) student backbone.
    return _dino_student('dinov2_vitg14_reg', **kwargs)
================================================
FILE: nit/models/nvidia_radio/radio/extra_timm_models.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from timm.models import register_model
from timm.models.vision_transformer import (
VisionTransformer,
_create_vision_transformer as _timm_create_vision_transformer,
Mlp,
Block,
LayerScale as TIMMLayerScale,
)
# Import these to also register them
from . import dinov2_arch
@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Tiny (ViT-Ti/14)
    """
    # patch_size is 14 here (the registry name agrees); the docstring
    # previously said /16.
    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
    model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Small (ViT-S/14)
    """
    # NOTE(review): built from timm's patch16 variant with patch_size
    # overridden to 14 — presumably intentional; verify pretrained weights
    # still load correctly with the changed patch embedding.
    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # Based on the DINOv2 register-token variant, with patch size overridden
    # from 14 to 16 and img_size rescaled to match (518 * 16 // 14 == 592).
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
    )
    model = _create_vision_transformer(
        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
@register_model
def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # Reuse the DINOv2 register-token config, overriding patch size 14 -> 16
    # and rescaling img_size accordingly.
    overrides = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        init_values=1e-5,
        reg_tokens=4,
        no_embed_class=True,
        img_size=518 * 16 // 14,
    )
    return _create_vision_transformer(
        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(overrides, **kwargs))
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
    # There is no pretrained ViT-H/16 upstream; adapt the ViT-H/14 weights
    # when pretraining is requested.
    variant = 'vit_huge_patch14_224' if pretrained else 'vit_huge_patch16_224'
    return _create_vision_transformer(variant, pretrained=bool(pretrained), **dict(model_args, **kwargs))
@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge (ViT-H/16) with a LayerNorm inserted inside every MLP block.
    """
    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
    # Give each MLP a LayerNorm over its hidden features; skip ones that
    # already have it.
    for module in model.modules():
        if not isinstance(module, Mlp):
            continue
        if isinstance(module.norm, nn.LayerNorm):
            continue
        module.norm = nn.LayerNorm(module.fc1.out_features)
    return model
@register_model
def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
    # NOTE: `pretrained` is accepted for API compatibility but ignored here
    # (the factory is always called with pretrained=False).
    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
    if scaled_ln:
        # Rescale each block's LayerNorm output by 1/sqrt(depth).
        _apply_scaled_ln(model)
    return model
@register_model
def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ViT-bigG/14: 48-layer, 1664-dim ViT with LayerScale (init 1e-6).

    NOTE: `pretrained` is accepted for API compatibility but ignored here.
    """
    args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
    args.update(kwargs)
    return _create_vision_transformer('vit_bigG_patch14', pretrained=False, **args)
def _create_vision_transformer(*args, **kwargs):
    """Wrapper around timm's factory that swaps in HF-Hub-safe LayerScale modules."""
    model = _timm_create_vision_transformer(*args, **kwargs)
    _patch_layer_scale(model)
    return model
def _patch_layer_scale(model: VisionTransformer):
    """Replace every TIMM LayerScale in `model` with the DINOv2 variant, in place.

    The DINOv2 LayerScale stores its scale under a name other than `gamma`,
    so that HFHub's weight-renaming logic doesn't mangle it. Weights are
    copied across, making the swap numerically a no-op.

    (Also removes a stray trailing `pass` from the original.)
    """
    def replace_ls(old_ls: TIMMLayerScale):
        # Same width and inplace flag; the state dict carries the learned scale.
        new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
        new_ls.load_state_dict(old_ls.state_dict())
        return new_ls
    # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name
    # other than gamma, so that HFHub doesn't mess with it!
    for mod in model.modules():
        if isinstance(mod, Block):
            if isinstance(mod.ls1, TIMMLayerScale):
                mod.ls1 = replace_ls(mod.ls1)
            if isinstance(mod.ls2, TIMMLayerScale):
                mod.ls2 = replace_ls(mod.ls2)
class ScaledLayerNorm(nn.LayerNorm):
    '''
    LayerNorm whose output is rescaled by 1/sqrt(depth).

    https://arxiv.org/pdf/2502.05795v1
    '''
    def __init__(self, ln_base: nn.LayerNorm, depth: int = 1):
        """Wrap an existing LayerNorm, copying its configuration and weights.

        Args:
            ln_base: the LayerNorm to clone shape/eps/affine settings from.
            depth: 1-based block depth; output is scaled by 1/sqrt(depth).
                Bug fix: the previous default of 0 always raised
                ZeroDivisionError; explicit positive depths behave as before.
        """
        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
        self.load_state_dict(ln_base.state_dict())
        # Non-persistent: recomputed from `depth` at construction, not checkpointed.
        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
    def forward(self, x):
        y = super().forward(x)
        y = y * self.ln_scale
        return y
class DyT(nn.Module):
    """Dynamic-Tanh layer: ``gamma * tanh(alpha * x) + beta``.

    A learnable, normalization-free replacement for LayerNorm with a scalar
    input gain `alpha` and per-channel affine parameters.
    """
    def __init__(self, C: int, init_alpha: float):
        super().__init__()
        # Scalar input gain plus per-channel scale/shift.
        self.alpha = nn.Parameter(torch.full((1,), init_alpha))
        self.gamma = nn.Parameter(torch.ones(C))
        self.beta = nn.Parameter(torch.zeros(C))
    def forward(self, x: torch.Tensor):
        squashed = F.tanh(self.alpha * x)
        return squashed * self.gamma + self.beta
@register_model
def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ ViT-Large (ViT-L/16) with every block LayerNorm swapped for DyT.
    """
    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    # Replace the pre-attention/pre-MLP LayerNorms with Dynamic-Tanh modules.
    _replace_ln(model, lambda ln, depth: DyT(ln.normalized_shape[0], init_alpha=0.9))
    return model
def _apply_scaled_ln(model: VisionTransformer):
    """Swap each block's LayerNorms for 1/sqrt(depth)-scaled ones, in place."""
    warnings.warn('Post-LayerNorm scaling activated!')
    _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
def _replace_ln(model: VisionTransformer, fn):
    """Swap every block's norm1/norm2 LayerNorm for ``fn(ln, depth=depth)``.

    `depth` is 1-based block index; non-LayerNorm attributes are left alone.
    """
    def _swap(block: Block, depth: int, attr: str):
        current = getattr(block, attr)
        if isinstance(current, nn.LayerNorm):
            setattr(block, attr, fn(current, depth=depth))
    for depth, block in enumerate(model.blocks, start=1):
        _swap(block, depth, 'norm1')
        _swap(block, depth, 'norm2')
================================================
FILE: nit/models/nvidia_radio/radio/feature_normalizer.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from collections import namedtuple
from typing import NamedTuple, Optional, Tuple
import torch
from torch import nn
def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
if x.ndim <= 3:
x = x - mean
x = x @ tx.T
elif x.ndim == 4:
x = x - mean.reshape(1, -1, 1, 1)
kernel = tx.reshape(*tx.shape, 1, 1)
x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)
else:
raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')
return x
class FeatureNormalizer(nn.Module):
    """Fixed mean-subtraction + linear projection over the embedding dim.

    Buffers default to the identity transform; real values are loaded from a
    checkpoint's state dict.
    """
    def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))
        self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _run_kernel(x, self.mean, self.tx)
class InterFeatState(NamedTuple):
    """(features, alpha) pair returned by intermediate feature normalizers."""
    y: torch.Tensor
    alpha: torch.Tensor
class IntermediateFeatureNormalizerBase(nn.Module):
    """Interface for normalizing intermediate transformer activations."""
    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
        raise NotImplementedError()
class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    """Whitens intermediate activations with per-layer means/rotations/alphas.

    Buffers (identity-initialized; real values come from a checkpoint):
        alphas:   (num_intermediates,) per-layer scale factors.
        rotation: (embed_dim, embed_dim) shared, or
                  (num_intermediates, embed_dim, embed_dim) when rot_per_layer.
        means:    (num_intermediates, embed_dim) per-layer means.
    """
    def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))
        rot = torch.eye(embed_dim, dtype=dtype)
        if rot_per_layer:
            # One independent rotation per intermediate layer.
            rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)
        self.register_buffer('rotation', rot.contiguous())
        self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))
    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
        """Normalize layer `index` activations.

        Args:
            x: (B, T, C) token tensor, or (B, C, H, W) feature map.
            index: which layer's mean/alpha to use.
            rot_index: which rotation to use; defaults to `index`.
            skip: number of leading tokens (CLS/registers) passed through
                unnormalized; only valid for 3-D inputs.
        """
        if rot_index is None:
            rot_index = index
        if skip:
            assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.'
            # Split off the summary tokens; they bypass the kernel entirely.
            prefix, x = x[:, :skip], x[:, skip:]
        rotation = self._get_rotation(rot_index)
        y = _run_kernel(x, self.means[index], rotation)
        alpha = self.alphas[index]
        if skip:
            # Prefix tokens get alpha == 1; normalized tokens share the layer alpha.
            alpha = torch.cat([
                torch.ones(skip, dtype=alpha.dtype, device=alpha.device),
                alpha[None].expand(y.shape[1]),
            ]).reshape(1, -1, 1)
            y = torch.cat([prefix, y], dim=1)
        else:
            # Broadcast the scalar alpha to match the token/spatial layout.
            if x.ndim == 3:
                alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)
            elif x.ndim == 4:
                alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])
            else:
                raise ValueError(f'Unsupported input dimension: {x.ndim}')
        return InterFeatState(y, alpha)
    def _get_rotation(self, rot_index: int) -> torch.Tensor:
        # A 2-D rotation buffer is shared across layers; 3-D is per-layer.
        if self.rotation.ndim == 2:
            return self.rotation
        return self.rotation[rot_index]
class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    """No-op normalizer: returns the input untouched with alpha == 1.

    Instances are cached per (dtype, device) so repeated lookups share one module.
    """
    instances = dict()

    def __init__(self, dtype: torch.dtype, device: torch.device):
        super().__init__()
        self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device))

    @staticmethod
    def get_instance(dtype: torch.dtype, device: torch.device):
        key = (dtype, device)
        instance = NullIntermediateFeatureNormalizer.instances.get(key, None)
        if instance is None:
            instance = NullIntermediateFeatureNormalizer(dtype, device)
            NullIntermediateFeatureNormalizer.instances[key] = instance
        return instance

    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
        return InterFeatState(x, self.alpha)
================================================
FILE: nit/models/nvidia_radio/radio/forward_intermediates.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
from types import MethodType
import torch
from torch import nn
from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
def _take_indices(
num_blocks: int,
n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
if isinstance(n, int):
assert n >= 0
take_indices = {x for x in range(num_blocks - n, num_blocks)}
else:
take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
return take_indices, max(take_indices)
def forward_intermediates(
        model: nn.Module,
        patch_extractor: Callable[[torch.Tensor], torch.Tensor],
        norm: nn.Module,
        num_summary_tokens: int,
        num_cls_tokens: int,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
        return_prefix_tokens: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
        aggregation: Optional[str] = "sparse",
        inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
        norm_alpha_scheme = "post-alpha",
        block_kwargs: Dict = None,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.
    The Dense layer aggregation method is inspired from the paper: "Dense Connector for MLLMs"
    by Yao, Huanjin et al. (2024). arXiv preprint arXiv:2405.13800}
    Args:
        x: Input image tensor
        indices: Take last n blocks if int, select matching indices if sequence
        return_prefix_tokens: Return both prefix and spatial intermediate tokens
        norm: Apply norm layer to all intermediates
        stop_early: Stop iterating over blocks when last desired intermediate hit
        output_fmt: Shape of intermediate feature outputs
        intermediates_only: Only return intermediate features
        aggregation: intermediate layer aggregation method (sparse or dense)
        norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha")
    Returns:
        intermediates only, or (final normed tokens, intermediates).
    """
    assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
    assert aggregation in ('sparse', 'dense'), 'Aggregation must be one of sparse or dense.'
    reshape = output_fmt == 'NCHW'
    intermediates = []
    block_kwargs = block_kwargs or dict()
    blocks = model.blocks
    take_indices, max_index = _take_indices(len(blocks), indices)
    # Sorted so take_indices[take_off] walks the selected layers in order.
    take_indices = sorted(take_indices)
    # forward pass
    B, _, height, width = x.shape
    x = patch_extractor(x)
    if stop_early:
        # Skip all blocks past the last requested intermediate.
        blocks = blocks[:max_index + 1]
    if inter_feature_normalizer is None or norm_alpha_scheme == 'none':
        # Identity normalizer: y == x, alpha == 1.
        inter_feature_normalizer = NullIntermediateFeatureNormalizer.get_instance(x.dtype, x.device)
    assert norm_alpha_scheme in ('none', 'pre-alpha', 'post-alpha'), f'Unsupported alpha scheme: {norm_alpha_scheme}'
    post_alpha_scheme = norm_alpha_scheme == 'post-alpha'
    # Dense-aggregation running state; reset each time a take-index is emitted.
    accumulator = 0
    alpha_sum = 0
    num_accumulated = 0
    take_off = 0
    for i, blk in enumerate(blocks):
        x = blk(x, **block_kwargs)
        if aggregation == "dense":
            # Arbitrarily use the rotation matrix from the final layer in the dense group
            y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
            if post_alpha_scheme:
                # post-alpha: average alphas separately, apply once at emit time.
                accumulator = accumulator + y
                alpha_sum = alpha_sum + alpha
            else:
                # pre-alpha: scale each layer's features before accumulating.
                accumulator = accumulator + (alpha * y)
                alpha_sum += 1
            num_accumulated += 1
        if i == take_indices[take_off]:
            if aggregation == "dense":
                # Emit the mean of the group, scaled by the mean alpha.
                alpha = alpha_sum / num_accumulated
                x_ = alpha * accumulator / num_accumulated
                num_accumulated = 0
                accumulator = 0
                alpha_sum = 0
            else:
                # sparse: normalize just this layer's output.
                y, alpha = inter_feature_normalizer(x, i, skip=num_summary_tokens)
                x_ = alpha * y
            # normalize intermediates with final norm layer if enabled
            intermediates.append(norm(x_))
            take_off = min(take_off + 1, len(take_indices) - 1)
    # process intermediates
    # split prefix (e.g. class, distill) and spatial feature tokens
    prefix_tokens = [y[:, :num_cls_tokens] for y in intermediates]
    intermediates = [y[:, num_summary_tokens:] for y in intermediates]
    if reshape:
        # reshape to BCHW output format
        H = height // model.patch_size
        W = width // model.patch_size
        intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
    if not torch.jit.is_scripting() and return_prefix_tokens:
        # return_prefix not support in torchscript due to poor type handling
        intermediates = list(zip(prefix_tokens, intermediates))
    if intermediates_only:
        return intermediates
    x = norm(x)
    return x, intermediates
================================================
FILE: nit/models/nvidia_radio/radio/hf_model.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Callable, Dict, Optional, List, Union
from timm.models import VisionTransformer
import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel
from .common import RESOURCE_MAP, DEFAULT_VERSION
# Import all required modules.
from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
from .adaptor_generic import GenericAdaptor, AdaptorBase
from .adaptor_mlp import create_mlp_from_config
from .adaptor_registry import adaptor_registry
from .cls_token import ClsToken
from .dinov2_arch import dinov2_vitg14_reg
from .enable_cpe_support import enable_cpe
from .enable_spectral_reparam import configure_spectral_reparam_from_args
from .eradio_model import eradio
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .forward_intermediates import forward_intermediates
from .radio_model import create_model_from_args
from .radio_model import RADIOModel as RADIOModelBase, Resolution
from .input_conditioner import get_default_conditioner, InputConditioner
from .open_clip_adaptor import OpenCLIP_RADIO
from .vit_patch_generator import ViTPatchGenerator
from .vitdet import apply_vitdet_arch, VitDetArgs
# Register extra models
from .extra_timm_models import *
from .extra_models import *
class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models."""
    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = DEFAULT_VERSION,
        patch_size: Optional[int] = None,
        max_resolution: Optional[int] = None,
        preferred_resolution: Optional[Resolution] = None,
        adaptor_names: Union[str, List[str]] = None,
        adaptor_configs: Dict[str, Dict[str, int]] = None,
        vitdet_window_size: Optional[int] = None,
        feature_normalizer_config: Optional[dict] = None,
        inter_feature_normalizer_config: Optional[dict] = None,
        **kwargs,
    ):
        """Build the config, filling unset fields from the RESOURCE_MAP entry
        selected by ``version``.

        Args:
            args: Training-args dict used to rebuild the backbone model.
            version: Key into RESOURCE_MAP identifying the released checkpoint.
            patch_size, max_resolution, preferred_resolution: Optional
                overrides for the resource defaults.
            adaptor_names: Name(s) of adaptor heads to instantiate.
            adaptor_configs: Per-adaptor MLP configuration dicts.
            vitdet_window_size: Optional ViTDet windowed-attention size.
            feature_normalizer_config: Dict with "embed_dim" for the final
                feature normalizer, or None to disable it.
            inter_feature_normalizer_config: Dict describing the intermediate
                feature normalizer, or None to disable it.
        """
        self.args = args
        # torch dtypes are not JSON-serializable; keep only the textual
        # suffix of their repr (torch.float32 -> "float32",
        # "bfloat16" stays "bfloat16").
        if args is not None:
            for field in ("dtype", "amp_dtype"):
                if field in args:
                    args[field] = str(args[field]).split(".")[-1]
        self.version = version
        resource = RESOURCE_MAP[version]
        # Explicit values win; otherwise fall back to the checkpoint defaults.
        self.patch_size = patch_size if patch_size else resource.patch_size
        self.max_resolution = max_resolution if max_resolution else resource.max_resolution
        self.preferred_resolution = (
            preferred_resolution if preferred_resolution else resource.preferred_resolution
        )
        self.adaptor_names = adaptor_names
        self.adaptor_configs = adaptor_configs
        self.vitdet_window_size = vitdet_window_size
        self.feature_normalizer_config = feature_normalizer_config
        self.inter_feature_normalizer_config = inter_feature_normalizer_config
        super().__init__(**kwargs)
class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO.
    This class inherits from PreTrainedModel, which provides
    HuggingFace's functionality for loading and saving models.
    """
    config_class = RADIOConfig
    def __init__(self, config: RADIOConfig):
        super().__init__(config)
        # Rebuild an argparse-like namespace from the serialized args dict so
        # that create_model_from_args() can consume it unchanged.
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)
        self.config = config
        model = create_model_from_args(args)
        input_conditioner: InputConditioner = get_default_conditioner()
        # dtype was serialized as a string by RADIOConfig; map it back to the
        # torch dtype object before casting the model.
        dtype = getattr(args, "dtype", torch.float32)
        if isinstance(dtype, str):
            # Convert the dtype's string representation back to a dtype.
            # For example for torch.float32 we will store "float32",
            # for "bfloat16" we will store "bfloat16".
            dtype = getattr(torch, dtype)
        model.to(dtype=dtype)
        input_conditioner.dtype = dtype
        # Indices of teachers whose summary token contributes to the backbone
        # summary (teachers may opt out via use_summary=False).
        summary_idxs = torch.tensor(
            [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
            dtype=torch.int64,
        )
        adaptor_configs = config.adaptor_configs
        adaptor_names = config.adaptor_names or []
        # Build one generic MLP adaptor head per requested name; weights are
        # loaded later from the checkpoint.
        adaptors = dict()
        for adaptor_name in adaptor_names:
            mlp_config = adaptor_configs[adaptor_name]
            adaptor = GenericAdaptor(args, None, None, mlp_config)
            adaptor.head_idx = mlp_config["head_idx"]
            adaptors[adaptor_name] = adaptor
        feature_normalizer = None
        if config.feature_normalizer_config is not None:
            # Actual normalization values will be restored when loading checkpoint weights.
            feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])
        inter_feature_normalizer = None
        if config.inter_feature_normalizer_config is not None:
            inter_feature_normalizer = IntermediateFeatureNormalizer(
                config.inter_feature_normalizer_config["num_intermediates"],
                config.inter_feature_normalizer_config["embed_dim"],
                rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
                dtype=dtype)
        # Everything is assembled into the (non-HF) RADIOModelBase; this class
        # is only a thin HF-compatible wrapper around it.
        self.radio_model = RADIOModelBase(
            model,
            input_conditioner,
            summary_idxs=summary_idxs,
            patch_size=config.patch_size,
            max_resolution=config.max_resolution,
            window_size=config.vitdet_window_size,
            preferred_resolution=config.preferred_resolution,
            adaptors=adaptors,
            feature_normalizer=feature_normalizer,
            inter_feature_normalizer=inter_feature_normalizer,
        )
    # All of the properties and methods below delegate to the wrapped
    # RADIOModelBase instance.
    @property
    def adaptors(self) -> nn.ModuleDict:
        return self.radio_model.adaptors
    @property
    def model(self) -> VisionTransformer:
        return self.radio_model.model
    @property
    def input_conditioner(self) -> InputConditioner:
        return self.radio_model.input_conditioner
    @property
    def num_summary_tokens(self) -> int:
        return self.radio_model.num_summary_tokens
    @property
    def patch_size(self) -> int:
        return self.radio_model.patch_size
    @property
    def max_resolution(self) -> int:
        return self.radio_model.max_resolution
    @property
    def preferred_resolution(self) -> Resolution:
        return self.radio_model.preferred_resolution
    @property
    def window_size(self) -> int:
        return self.radio_model.window_size
    @property
    def min_resolution_step(self) -> int:
        return self.radio_model.min_resolution_step
    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return self.radio_model.make_preprocessor_external()
    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
        return self.radio_model.get_nearest_supported_resolution(height, width)
    def switch_to_deploy(self):
        return self.radio_model.switch_to_deploy()
    def forward(self, x: torch.Tensor):
        return self.radio_model.forward(x)
================================================
FILE: nit/models/nvidia_radio/radio/input_conditioner.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Union, Tuple
import torch
from torch import nn
norm_t = Union[Tuple[float, float, float], torch.Tensor]
class InputConditioner(nn.Module):
    """Normalizes input images as ``(x - mean) / std``, optionally casting
    the result to a fixed dtype.

    The mean/std are pre-divided by ``input_scale`` at construction time so
    the forward pass needs no extra multiply.
    """
    def __init__(self,
                 input_scale: float,
                 norm_mean: norm_t,
                 norm_std: norm_t,
                 dtype: torch.dtype = None,
    ):
        super().__init__()
        self.dtype = dtype
        # Stats are stored as (C, 1, 1) buffers so they broadcast over
        # (B, C, H, W) inputs and follow the module across devices.
        mean = torch.as_tensor(norm_mean, dtype=torch.float32).view(-1, 1, 1)
        std = torch.as_tensor(norm_std, dtype=torch.float32).view(-1, 1, 1)
        self.register_buffer("norm_mean", mean / input_scale)
        self.register_buffer("norm_std", std / input_scale)
    def forward(self, x: torch.Tensor):
        """Return the normalized (and optionally dtype-cast) image batch."""
        normalized = (x - self.norm_mean) / self.norm_std
        if self.dtype is None:
            return normalized
        return normalized.to(self.dtype)
def get_default_conditioner():
    """Return an InputConditioner configured with the OpenAI CLIP
    normalization statistics (the defaults used by RADIO)."""
    # Imported lazily so timm is only required when this helper is used.
    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    mean, std = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    return InputConditioner(input_scale=1.0, norm_mean=mean, norm_std=std)
def _to_tensor(v: norm_t):
    """Convert a per-channel statistic (3-tuple or tensor) into a float32
    tensor shaped (C, 1, 1) so it broadcasts over (B, C, H, W) images."""
    t = torch.as_tensor(v, dtype=torch.float32)
    return t.view(-1, 1, 1)
================================================
FILE: nit/models/nvidia_radio/radio/open_clip_adaptor.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
import torch
from torch import nn
import torch.nn.functional as F
from .adaptor_registry import adaptor_registry, dict_t, state_t
from .adaptor_generic import GenericAdaptor
class OpenCLIP_RADIO(GenericAdaptor):
    """Adaptor that pairs the RADIO visual features with an OpenCLIP text tower,
    enabling CLIP-style text encoding alongside the RADIO backbone."""
    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        super().__init__(main_config, adaptor_config, state)
        # Imported lazily so open_clip is only required when this adaptor is used.
        import open_clip
        self.oc_model = open_clip.create_model_from_pretrained(
            model_name=adaptor_config['model'],
            pretrained=adaptor_config['pretrained'],
            return_transform=False,
        )
        # Unload these parameters
        # (the visual tower is redundant here: RADIO supplies the image
        # features; only the text tower + tokenizer are kept).
        self.oc_model.visual = None
        self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
    def encode_text(self, text, normalize: bool = False):
        """Encode pre-tokenized text with the OpenCLIP text tower;
        set ``normalize=True`` for unit-norm embeddings."""
        return self.oc_model.encode_text(text, normalize=normalize)
@adaptor_registry.register_adaptor("open_clip")
def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Registry factory: build an OpenCLIP_RADIO adaptor for the "open_clip" key."""
    return OpenCLIP_RADIO(main_config, adaptor_config, state)
================================================
FILE: nit/models/nvidia_radio/radio/radio_model.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
import torch
from torch import nn
from timm.models import create_model, VisionTransformer
from types import MethodType
from .enable_cpe_support import enable_cpe
from .input_conditioner import InputConditioner
from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
from . import eradio_model
from .enable_spectral_reparam import configure_spectral_reparam_from_args
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from . import dual_hybrid_vit
class Resolution(NamedTuple):
    """Image resolution expressed as (height, width) in pixels."""
    height: int
    width: int
class RADIOModel(nn.Module):
    """Core RADIO wrapper: input conditioning + ViT backbone + optional
    adaptor heads and feature normalizers.

    The wrapped ``model`` may be a timm VisionTransformer (optionally with a
    CPE patch generator), an ERADIO model, or any module exposing
    ``forward_features``; ``_extract_final`` dispatches on that type.
    """
    def __init__(
        self,
        model: nn.Module,
        input_conditioner: InputConditioner,
        patch_size: int,
        max_resolution: int,
        preferred_resolution: Resolution,
        summary_idxs: Optional[torch.Tensor] = None,
        window_size: int = None,
        adaptors: Dict[str, AdaptorBase] = None,
        feature_normalizer: Optional[FeatureNormalizer] = None,
        inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,
    ):
        super().__init__()
        self.model = model
        self.input_conditioner = input_conditioner
        # Registered as a buffer so it moves with .to(device) and is persisted
        # in the state dict; None means "use all summary tokens".
        if summary_idxs is not None:
            self.register_buffer('summary_idxs', summary_idxs)
        else:
            self.summary_idxs = None
        self._preferred_resolution = preferred_resolution
        self._patch_size = patch_size
        self._max_resolution = max_resolution
        self._window_size = window_size
        adaptors = adaptors or dict()
        self.adaptors = nn.ModuleDict(adaptors)
        # Identity when no final feature normalization is configured.
        if feature_normalizer is None:
            feature_normalizer = nn.Identity()
        self.feature_normalizer = feature_normalizer
        self.inter_feature_normalizer = inter_feature_normalizer
    @property
    def num_summary_tokens(self) -> int:
        """Number of non-spatial tokens (cls + registers) preceding the patch
        tokens, resolved from whichever attribute the backbone exposes."""
        if hasattr(self.model, 'num_summary_tokens'):
            return self.model.num_summary_tokens
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.num_skip
        elif getattr(self.model, 'global_pool', None) == 'avg':
            # Average-pooled models have no prefix tokens.
            return 0
        return 1
    @property
    def num_cls_tokens(self) -> int:
        """Number of class tokens, resolved like num_summary_tokens."""
        if hasattr(self.model, 'num_cls_tokens'):
            return self.model.num_cls_tokens
        patch_gen = getattr(self.model, 'patch_generator', None)
        if patch_gen is not None:
            return patch_gen.num_cls_tokens
        elif getattr(self.model, 'global_pool', None) == 'avg':
            return 0
        return 1
    @property
    def patch_size(self) -> int:
        """Patch size: explicit override first, then backbone attributes."""
        if self._patch_size is not None:
            return self._patch_size
        if hasattr(self.model, "patch_size"):
            return self.model.patch_size
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.patch_size
        return None
    @property
    def max_resolution(self) -> int:
        return self._max_resolution
    @property
    def preferred_resolution(self) -> Resolution:
        return self._preferred_resolution
    @property
    def window_size(self) -> int:
        return self._window_size
    @property
    def min_resolution_step(self) -> int:
        """Smallest resolution increment supported: patch size, multiplied by
        the window size when windowed (ViTDet) attention is enabled."""
        res = self.patch_size
        if self.window_size is not None:
            res *= self.window_size
        return res
    @property
    def blocks(self) -> Iterable[nn.Module]:
        """The backbone's transformer blocks, or None if it has none."""
        blocks = getattr(self.model, 'blocks', None)
        if blocks is not None:
            return blocks
        return None
    @property
    def embed_dim(self) -> int:
        return self.model.embed_dim
    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Detach the input conditioner (replacing it with Identity) and
        return it, so normalization can be applied e.g. in the dataloader."""
        ret = self.input_conditioner
        self.input_conditioner = nn.Identity()
        return ret
    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
        """Round (height, width) to the nearest multiple of
        min_resolution_step, clamped to at least one step."""
        height = int(round(height / self.min_resolution_step) * self.min_resolution_step)
        width = int(round(width / self.min_resolution_step) * self.min_resolution_step)
        height = max(height, self.min_resolution_step)
        width = max(width, self.min_resolution_step)
        return Resolution(height=height, width=width)
    def switch_to_deploy(self):
        """Forward to the backbone's switch_to_deploy hook, if it has one."""
        fn = getattr(self.model, 'switch_to_deploy', None)
        if fn is not None:
            fn()
    def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        '''
        Forward process for model.
        Args:
            x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
               otherwise `x` is expected to be mean centered with unit standard deviation.
            feature_format: ['NLC', 'NCHW'] - The output format for the features.
        '''
        res_step = self.min_resolution_step
        # Reject resolutions the patch/window grid cannot tile exactly.
        if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
            raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
                             '`self.get_nearest_supported_resolution(, ) is provided as a convenience API. '
                             f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
        x = self.input_conditioner(x)
        y = self.model.forward_features(x)
        ret = self._extract_final(x, y, feature_fmt=feature_fmt)
        return ret
    def forward_pack(self, x: List[torch.Tensor], feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        '''
        Forward process for model.
        Args:
            x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
               otherwise `x` is expected to be mean centered with unit standard deviation.
            feature_format: ['NLC', 'NCHW'] - The output format for the features.
        '''
        # Packed variant: x is a list of variable-resolution images that the
        # backbone flattens into one long token sequence; cu_seqlens gives the
        # cumulative per-image token offsets.
        res_step = self.min_resolution_step
        for _x in x:
            if res_step is not None and (_x.shape[-2] % res_step != 0 or _x.shape[-1] % res_step != 0):
                raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
                                 '`self.get_nearest_supported_resolution(, ) is provided as a convenience API. '
                                 f'Input: {_x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*_x.shape[-2:])}')
        x = [self.input_conditioner(_x) for _x in x]
        y, cu_seqlens = self.model.forward_features(x)
        all_summary, spatial_features = [], []
        num_cls_tokens = self.model.patch_generator.num_cls_tokens
        num_skip = self.model.patch_generator.num_skip
        # Slice each image's cls tokens and spatial tokens out of the packed
        # sequence, then re-concatenate across images.
        for i in range(len(cu_seqlens)-1):
            summary = y[cu_seqlens[i]: cu_seqlens[i+1]][: num_cls_tokens]
            all_feat = y[cu_seqlens[i]: cu_seqlens[i+1]][num_skip :]
            all_summary.append(summary)
            spatial_features.append(all_feat)
        all_summary = torch.cat(all_summary)
        spatial_features = torch.cat(spatial_features)
        return all_summary, spatial_features
    def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):
        """Split backbone output `y` into (summary, spatial features), apply
        the feature normalizer and any adaptor heads.

        `x` is the conditioned input, only used to recover the spatial grid
        for NCHW formatting.
        """
        if isinstance(self.model, VisionTransformer):
            patch_gen = getattr(self.model, "patch_generator", None)
            if patch_gen is not None:
                # CPE path: cls tokens first, then registers, then patches.
                all_summary = y[:, : patch_gen.num_cls_tokens]
                if self.summary_idxs is not None:
                    bb_summary = all_summary[:, self.summary_idxs]
                else:
                    bb_summary = all_summary
                all_feat = y[:, patch_gen.num_skip :]
            elif self.model.global_pool == "avg":
                all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
                bb_summary = all_summary
                all_feat = y
            else:
                # Plain ViT: single cls token at position 0.
                all_summary = y[:, 0]
                bb_summary = all_summary
                all_feat = y[:, 1:]
        elif isinstance(self.model, eradio_model.ERADIO):
            # ERADIO returns (summary, spatial map); flatten the map to NLC.
            _, f = y
            all_feat = f.flatten(2).transpose(1, 2)
            all_summary = all_feat.mean(dim=1)
            bb_summary = all_summary
        elif isinstance(y, (list, tuple)):
            all_summary, all_feat = y
            bb_summary = all_summary
        else:
            all_summary = y[:, :self.num_cls_tokens]
            if self.summary_idxs is not None and all_summary.shape[1] > 1:
                # NOTE(review): the inner `shape[1] == 1` branch is unreachable
                # because the outer condition already requires shape[1] > 1 —
                # the single-summary expand path can never run; confirm intent
                # against upstream RADIO.
                if all_summary.shape[1] == 1:
                    # Create dummy duplicates
                    all_summary = all_summary.expand(-1, 128, -1)
                bb_summary = all_summary[:, self.summary_idxs]
            else:
                bb_summary = all_summary
            all_feat = y[:, self.num_summary_tokens:]
        all_feat = self.feature_normalizer(all_feat)
        if feature_fmt == 'NCHW':
            # Recover the (H/ps, W/ps) patch grid from the input resolution.
            fmt_feat = (all_feat.reshape(all_feat.shape[0], x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2])
                            .permute(0, 3, 1, 2)
            )
        elif feature_fmt == 'NLC':
            fmt_feat = all_feat
        else:
            raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]')
        ret = RadioOutput(bb_summary.flatten(1), fmt_feat)
        if self.adaptors:
            # With adaptors, return a dict keyed by head name, with the raw
            # backbone output under "backbone".
            ret = dict(backbone=ret)
            for name, adaptor in self.adaptors.items():
                if all_summary.ndim == 3:
                    if all_summary.shape[1] == 1:
                        summary = all_summary[:, 0]
                    else:
                        # Each adaptor reads its own teacher's summary token.
                        summary = all_summary[:, adaptor.head_idx]
                else:
                    summary = all_summary
                ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
                v = adaptor(ada_input).to(torch.float32)
                ret[name] = v
        return ret
    def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int], Tuple[int]]] = None,
        return_prefix_tokens: bool = False,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
        aggregation: Optional[str] = "sparse",
        norm_alpha_scheme: Optional[str] = "post-alpha",
    ) -> List[RadioOutput]:
        """ Forward features that returns intermediates.
        Args:
            x: Input image tensor
            indices: Take last n blocks if int, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC
            intermediates_only: Only return intermediate features
            aggregation: intermediate layer aggregation method (sparse or dense).
                Dense accumulation is done by averaging the features in each group.
            norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none")
                Only affects dense aggregation
        Returns:
            List of RadioOutput objects.
        """
        x = self.input_conditioner(x)
        intermediates = self.model.forward_intermediates(
            x,
            indices=indices,
            return_prefix_tokens=return_prefix_tokens,
            norm=norm,
            stop_early=stop_early,
            output_fmt=output_fmt,
            intermediates_only=intermediates_only,
            aggregation=aggregation,
            inter_feature_normalizer=self.inter_feature_normalizer,
            norm_alpha_scheme=norm_alpha_scheme,
        )
        if not intermediates_only:
            final, intermediates = intermediates
        def prepare_summary(summ: Optional[torch.Tensor]):
            # Select the configured teacher summaries and flatten to (B, D).
            if summ is None:
                return summ
            if self.summary_idxs is not None and summ.shape[1] > 1:
                summ = summ[:, self.summary_idxs]
            return summ.flatten(1)
        if return_prefix_tokens:
            radio_outputs = [
                RadioOutput(prepare_summary(summary), features)
                for summary, features in intermediates
            ]
        else:
            radio_outputs = intermediates
        if intermediates_only:
            return radio_outputs
        else:
            final = self._extract_final(x, final, feature_fmt=output_fmt)
            return final, radio_outputs
def create_model_from_args(args) -> nn.Module:
    """Instantiate the timm backbone described by a parsed-args namespace.

    Strips the classification head, optionally strips the final norm, and
    enables Cropped Position Embedding (CPE) support when configured.
    """
    # Input channel count: explicit override wins, otherwise derive it from
    # the configured input size; default to RGB.
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3
    # Skip weight initialization unless it's explicitly requested.
    weight_init = args.model_kwargs.pop("weight_init", "skip")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        weight_init=weight_init,
        **args.model_kwargs,
    )
    # RADIO consumes raw features: drop the final norm (unless requested)
    # and always drop the classification head.
    if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
        model.norm = nn.Identity()
    model.head = nn.Identity()
    if args.cpe_max_size is not None:
        # One cls token per unique teacher when cls_token_per_teacher is set.
        uq_teachers = {t['name'] for t in args.teachers}
        enable_cpe(
            model,
            args.cpe_max_size,
            num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1,
            register_multiple=getattr(args, 'register_multiple', None),
            num_registers=getattr(args, 'cpe_num_registers', None),
            support_packing=args.support_packing,
        )
    return model
================================================
FILE: nit/models/nvidia_radio/radio/vision_transformer_xpos.py
================================================
import math
from typing import Final, List, Optional, Tuple, Union
from einops import rearrange
from timm.models import register_model
import torch
from torch import Type, nn
from torch.nn import functional as F
from torch.nn.init import xavier_normal_, xavier_uniform_, zeros_
from .forward_intermediates import forward_intermediates
def _get_init_scale(num_encoder_layers: int, num_decoder_layers: int, is_encoder: bool):
if num_encoder_layers > 0 and num_decoder_layers == 0:
return math.sqrt(math.log(2 * num_encoder_layers))
if num_decoder_layers > 0 and num_encoder_layers == 0:
return math.sqrt(math.log(2 * num_decoder_layers))
if is_encoder:
# Both encoders and decoders
return math.sqrt(
0.33 * math.log(3 * num_decoder_layers) * math.log(2 * num_encoder_layers)
)
return math.sqrt(math.log(3 * num_decoder_layers))
def duplicate_interleave(m):
    """Duplicate each element twice along the flattened trailing dims,
    keeping the leading dim, e.g.::

        [[1,2],    [[1,1,2,2],
         [3,4], ->  [3,3,4,4],
         [5,6]]     [5,5,6,6]]
    """
    rows = m.shape[0]
    doubled = m.reshape(-1, 1).repeat(1, 2)
    return doubled.reshape(rows, -1)
def rotate_every_two(x):
    """Rotate adjacent channel pairs along the last dim:
    [a, b, c, d, ...] -> [-b, a, -d, c, ...] (the rotary-embedding rotation).

    Expects a tensor of at least 3 dims; pairs are taken over the last dim.
    """
    evens = x[:, :, 0::2]
    odds = x[:, :, 1::2]
    paired = torch.stack((-odds, evens), dim=-1)
    # Interleave back: (..., d/2, 2) -> (..., d).
    return paired.flatten(-2)
class XPosEmbedding2D(torch.nn.Module):
    """Implementation of xPos based on RotaryEmbedding from GPT-NeoX.
    This implementation is designed to operate on queries and keys that are compatible with
    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
    """
    def __init__(
        self,
        head_dim: int,
        base=50000,
        scale_base=512
    ):
        super().__init__()
        # Half the channels encode the y coordinate, half the x coordinate.
        half_dim = head_dim // 2
        self.half_dim = half_dim
        inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim))
        # Non-persistent: inv_freq is recomputable and excluded from state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        # Per-token-shape cache for the cos/sin/scale tables (see cos_sin).
        self.token_shape_cached = None
        self.batch_size_cached = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None
        self.scale_cached: torch.Tensor | None = None
        self.scale_base = scale_base
        # xPos per-channel decay factors in (0, 1]; stored as a (persistent)
        # buffer so it follows .to(device).
        self.register_buffer("scale",
                             (torch.arange(0, half_dim, 2) + 0.4 * half_dim) / (1.4 * half_dim))
    def cos_sin(
        self,
        token_shape: Tuple[int, int],
        device="cuda",
        dtype=torch.bfloat16,
    ) -> torch.Tensor:
        """Return cached (cos, sin, scale) tables of shape (1, H*W, head_dim)
        for a (H, W) patch grid.

        NOTE(review): the cache is keyed only on token_shape — if device or
        dtype change while the shape stays the same, the tables are recast
        (below) but not recomputed on the new device; confirm callers never
        mix devices with a fixed shape.
        """
        if token_shape != self.token_shape_cached:
            self.token_shape_cached = token_shape
            # Build a (y, x) coordinate grid and a frequency table per axis.
            y = torch.arange(token_shape[0], device=device, dtype=self.inv_freq.dtype)
            x = torch.arange(token_shape[1], device=device, dtype=self.inv_freq.dtype)
            x, y = torch.meshgrid(x, y, indexing='xy')
            y_freqs = torch.einsum("i,j->ij", y.flatten(), self.inv_freq)
            x_freqs = torch.einsum("i,j->ij", x.flatten(), self.inv_freq)
            # xPos decay grows/decays with distance from the origin.
            y_scales = self.scale ** y.flatten().div(self.scale_base)[:, None]
            x_scales = self.scale ** x.flatten().div(self.scale_base)[:, None]
            freqs = torch.cat([y_freqs, x_freqs], dim=-1)
            # Each frequency is used for a (cos, sin) channel pair.
            emb = torch.repeat_interleave(freqs, repeats=2, dim=-1)
            scales = torch.cat([y_scales, x_scales], dim=-1)
            scales = torch.repeat_interleave(scales, repeats=2, dim=-1)
            # Compute the trig tables in fp32 for accuracy, cast afterwards.
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()
            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]
            self.scale_cached = scales[None, :, :]
            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)
            self.scale_cached = self.scale_cached.type(dtype)
        return self.cos_cached, self.sin_cached, self.scale_cached
    def forward(self, q: torch.Tensor, k: torch.Tensor, token_shape: Tuple[int, int]):
        """Apply the 2D xPos rotation to (batch, seq, dim) queries and keys.
        Queries are multiplied by `scale`, keys by `1/scale`, so the decay
        cancels into a relative-distance weighting in q·k."""
        batch, seq_len, head_dim = q.shape
        cos, sin, scale = self.cos_sin(token_shape, q.device, q.dtype)
        # scale = self.scale**torch.arange(seq_len).to(self.scale).div(self.scale_base)[:, None]
        # scale = torch.repeat_interleave(scale, 2, dim=-1).to(q.device)
        # scale = torch.cat([scale, scale], dim=-1)
        # scale = 1
        return (
            (q * cos * scale) + (rotate_every_two(q) * sin * scale),
            (k * cos * (1 / scale)) + (rotate_every_two(k) * sin * (1 / scale)),
        )
class MagnetoAttention(nn.Module):
    """Multi-head self-attention with Magneto-style sub-LayerNorms
    (LayerNorm before the QKV projection and again after attention, before
    the output projection) and 2D xPos rotary embeddings on patch tokens."""
    def __init__(self, d_model: int, n_head: int, pos_emb: XPosEmbedding2D):
        super().__init__()
        self.num_heads = n_head
        self.head_dim = d_model // n_head
        # NOTE(review): self.scale is never used — F.scaled_dot_product_attention
        # applies its own default 1/sqrt(head_dim) scaling.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.proj = nn.Linear(d_model, d_model)
        # Shared across layers (constructed once by the owning ViT).
        self.pos_emb = pos_emb
        self.norm0 = nn.LayerNorm(d_model)
        self.norm1 = nn.LayerNorm(d_model)
    def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
        B, N, C = x.shape
        x = self.norm0(x)
        # Project to q, k, v while still in (B, N, C); heads are split only
        # after the positional rotation, so xPos operates over the full
        # embedding dimension (pos_emb was built with head_dim == d_model).
        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)
        # Prefix (cls/register) tokens carry no spatial position: only the
        # patch tokens receive the rotary embedding.
        q_pref = q[:, :num_prefix_tokens]
        q_patch = q[:, num_prefix_tokens:]
        k_pref = k[:, :num_prefix_tokens]
        k_patch = k[:, num_prefix_tokens:]
        q_patch, k_patch = self.pos_emb(q_patch, k_patch, patch_shape)
        q = torch.cat([q_pref, q_patch], dim=1)
        k = torch.cat([k_pref, k_patch], dim=1)
        def head_reshape(t: torch.Tensor):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q = head_reshape(q)
        k = head_reshape(k)
        v = head_reshape(v)
        x = F.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2).reshape(B, N, C)
        # Magneto sub-LN before the output projection.
        x = self.norm1(x)
        x = self.proj(x)
        return x
    def _reset_parameters(self):
        # Xavier init per Magneto; called from the owning layer's initialize().
        xavier_uniform_(self.qkv.weight)
        if self.qkv.bias is not None:
            zeros_(self.qkv.bias)
        xavier_normal_(self.proj.weight)
        zeros_(self.proj.bias)
class MagnetoTransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer in the Magneto style: sub-LN
    attention block plus a sub-LN feed-forward block, each with a residual
    connection, and gain-scaled Xavier initialization."""
    def __init__(self, d_model: int, nhead: int, pos_emb: XPosEmbedding2D,
                 num_encoder_layers: int, num_decoder_layers: int = 0,
                 dim_mhsa: int = 0,
                 dim_feedforward: int = 2048,
                 layer_norm_eps: float = 1e-5,
                 batch_first: bool = True):
        super().__init__()
        # dim_mhsa == 0 means "use the model width"; NOTE(review): dim_mhsa is
        # otherwise unused below — confirm whether it was meant to size the
        # attention block.
        if dim_mhsa == 0:
            dim_mhsa = d_model
        # Layer counts are kept only to compute the Magneto init gain.
        self._num_encoder_layers = num_encoder_layers
        self._num_decoder_layers = num_decoder_layers
        self.attn = MagnetoAttention(d_model, nhead, pos_emb)
        # FFN with a LayerNorm before each linear (sub-LN).
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.linear2 = nn.Linear(d_model, dim_feedforward)
        self.norm3 = nn.LayerNorm(dim_feedforward, eps=layer_norm_eps)
        self.linear3 = nn.Linear(dim_feedforward, d_model)
    def initialize(self):
        """Apply Magneto initialization: gain-scaled Xavier-normal on the FFN
        linears, and the attention module's own reset."""
        gamma = _get_init_scale(self._num_encoder_layers, self._num_decoder_layers, is_encoder=True)
        # Magneto Initialization
        for mod in self.children():
            if isinstance(mod, nn.Linear):
                xavier_normal_(mod.weight.data, gamma)
            elif isinstance(mod, MagnetoAttention):
                mod._reset_parameters()
    def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
        # Residual attention block followed by residual FFN block.
        x = x + self._sa_block(x, num_prefix_tokens, patch_shape)
        x = x + self._ff_block(x)
        return x
    def _sa_block(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:
        # Self-attention branch (normalization happens inside the attention).
        x = self.attn(x, num_prefix_tokens, patch_shape)
        return x
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        # Sub-LN feed-forward: LN -> Linear -> GELU -> LN -> Linear.
        x = self.norm2(x)
        x = self.linear2(x)
        x = F.gelu(x)
        x = self.norm3(x)
        x = self.linear3(x)
        return x
class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Variant using Magneto encoder layers and 2D xPos rotary position
    embeddings instead of learned absolute position embeddings.
    """
    dynamic_img_size: Final[bool]
    def __init__(
        self,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.,
        num_cls_tokens: int = 1,
        num_reg_tokens: int = 0,
    ) -> None:
        """
        Args:
            patch_size: Patch size.
            in_chans: Number of image input channels.
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            num_cls_tokens: Number of cls tokens
            num_reg_tokens: Number of register tokens.
        """
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_cls_tokens = num_cls_tokens
        self.num_reg_tokens = num_reg_tokens
        # Patchify with a strided conv (one embedding per patch).
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Learned cls + register tokens, prepended to the patch sequence.
        self.prefix_buffer = nn.Parameter(torch.randn(1, self.num_prefix_tokens, embed_dim) * .02)
        # One xPos embedding shared by every block. Note: it is constructed
        # with head_dim == embed_dim, so the rotation spans the full channel
        # dimension (see MagnetoAttention.forward).
        pos_emb = XPosEmbedding2D(embed_dim)
        self.blocks = nn.ModuleList([
            MagnetoTransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                num_encoder_layers=depth,
                num_decoder_layers=0,
                dim_feedforward=int(embed_dim * mlp_ratio),
                pos_emb=pos_emb,
            )
            for _ in range(depth)
        ])
        # Apply Magneto initialization to every layer.
        for block in self.blocks:
            block.initialize()
    @property
    def num_prefix_tokens(self):
        # Total non-patch tokens at the front of the sequence.
        return self.num_cls_tokens + self.num_reg_tokens
    @property
    def num_summary_tokens(self):
        return self.num_prefix_tokens
    def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the backbone and return (cls summaries, patch features)."""
        x, patch_shape = self._patchify(x)
        for block in self.blocks:
            x = block(x, self.num_prefix_tokens, patch_shape)
        summary = x[:, :self.num_cls_tokens]
        features = x[:, self.num_prefix_tokens:]
        return summary, features
    def forward_intermediates(self, x: torch.Tensor, norm: bool = False, **kwargs):
        """Delegate to the shared forward_intermediates helper; this model has
        no final norm, so intermediates are passed through unchanged."""
        patch_shape = tuple(d // self.patch_size for d in x.shape[-2:])
        def patch_extractor(x: torch.Tensor):
            x, _ = self._patchify(x)
            return x
        return forward_intermediates(
            self,
            patch_extractor=patch_extractor,
            num_summary_tokens=self.num_prefix_tokens,
            num_cls_tokens=self.num_cls_tokens,
            norm=lambda y: y,
            x=x,
            block_kwargs=dict(num_prefix_tokens=self.num_prefix_tokens, patch_shape=patch_shape),
            **kwargs,
        )
    def _patchify(self, x: torch.Tensor):
        """Embed patches, flatten to a token sequence, and prepend the
        cls/register prefix tokens. Returns (tokens, (H_patches, W_patches))."""
        x = self.patch_embed(x)
        patch_shape = x.shape[-2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        prefix = self.prefix_buffer.expand(x.shape[0], -1, -1)
        x = torch.cat([prefix, x], dim=1)
        return x, patch_shape
@register_model
def vit_base_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
    """ViT-B/16 with 2D xPos embeddings (768 dim, 12 layers, 12 heads).
    Extra kwargs are accepted for timm compatibility but ignored."""
    return VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        num_cls_tokens=num_cls_tokens,
        num_reg_tokens=num_reg_tokens,
    )
@register_model
def vit_large_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
    """ViT-L/16 with 2D xPos embeddings (1024 dim, 24 layers, 16 heads).
    Extra kwargs are accepted for timm compatibility but ignored."""
    return VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        num_cls_tokens=num_cls_tokens,
        num_reg_tokens=num_reg_tokens,
    )
@register_model
def vit_huge_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
    """ViT-H/16 with 2D xPos embeddings (1280 dim, 32 layers, 16 heads).
    Extra kwargs are accepted for timm compatibility but ignored."""
    return VisionTransformer(
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        num_cls_tokens=num_cls_tokens,
        num_reg_tokens=num_reg_tokens,
    )
@register_model
def vit_giant_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
    """ViT-g/16 with 2D xPos embeddings (1408 dim, 40 layers, 16 heads).
    Extra kwargs are accepted for timm compatibility but ignored."""
    return VisionTransformer(
        patch_size=16,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        num_cls_tokens=num_cls_tokens,
        num_reg_tokens=num_reg_tokens,
    )
@register_model
def vit_bigG_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:
    """ViT-G/16 with 2D xPos embeddings (1664 dim, 48 layers, 16 heads).
    Extra kwargs are accepted for timm compatibility but ignored."""
    return VisionTransformer(
        patch_size=16,
        embed_dim=1664,
        depth=48,
        num_heads=16,
        num_cls_tokens=num_cls_tokens,
        num_reg_tokens=num_reg_tokens,
    )
================================================
FILE: nit/models/nvidia_radio/radio/vit_patch_generator.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
from typing import Union, Tuple, Optional
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange
from .cls_token import ClsToken
input_dim_t = Union[int, Tuple[int, int]]
try:
# raise ImportError()
from indirect_grid_sample import indirect_grid_sample
except ImportError:
indirect_grid_sample = None
class ViTPatchGenerator(nn.Module):
    """Patchify an image, linearly embed the patches, and add absolute
    position encodings plus class/register tokens.

    When ``max_input_dims`` differs from ``input_dims`` the generator runs in
    CPE (cropped position embedding) mode: during training the position grid
    is randomly scaled/translated per sample; during eval it is resized to
    match the input grid.

    Bug fix vs. the original: ``apply_pos_enc`` used to return a bare tensor
    when ``abs_pos`` is False, while ``forward`` always unpacks two values —
    it now returns ``(patches, None)`` in that case.
    """

    def __init__(self,
                 patch_size: int,
                 embed_dim: int,
                 input_dims: input_dim_t,
                 abs_pos: bool = True,
                 normalize_patches: bool = False,
                 cls_token: bool = False,
                 max_input_dims: Optional[input_dim_t] = None,
                 pos_dropout: float = 0.0,
                 return_pos_enc: bool = False,
                 num_cls_tokens: int = 1,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
                 patch_bias: bool = False,
                 device=None, dtype=None,
    ):
        super().__init__()

        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        # Round the maximum dims up to a whole number of patches.
        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size)
            for d in max_input_dims
        )

        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim

        # Position-embedding grid size (in patches), at maximum resolution.
        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias, **factory)

        if abs_pos:
            scale = embed_dim ** -0.5
            self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the token sequence; also the position encoding when
        ``return_pos_enc`` is set."""
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)
        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    @property
    def apply_cls_token(self):
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        # Number of leading non-image tokens (cls + registers).
        return self.num_cls_tokens + self.num_registers

    def no_weight_decay(self):
        # Position embeddings are conventionally excluded from weight decay.
        return [
            'pos_embed',
        ]

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        """Copy a (possibly differently sized) position embedding, bicubically
        interpolating square source grids to this generator's grid."""
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'

            src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
            src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
            src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
        targ_embed.data.copy_(src_embed)

    def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
        """Copy patch-projection weights, resampling the (assumed RGB) kernels
        when the source patch size differs."""
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'

            src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
            src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
            src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Split the image into patches and project them to embed_dim."""
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(self,
                      patches: torch.Tensor,
                      patch_idxs: Optional[torch.Tensor] = None,
                      input_size: Optional[Tuple[int, int]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Add position encodings (with optional per-sample dropout during
        training) and return ``(patches, pos_enc)``; ``pos_enc`` is None when
        absolute positions are disabled."""
        if not self.abs_pos:
            # Fix: callers always unpack two values.
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            # Drop the position encoding for a random subset of samples.
            keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(self,
                    batch_size: int,
                    patch_idxs: Optional[torch.Tensor] = None,
                    input_size: Optional[Tuple[int, int]] = None,
    ) -> torch.Tensor:
        """Return position encodings, optionally gathered at ``patch_idxs``."""
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        # Gather only the requested patch positions.
        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
        """Resize/crop the learned position grid to ``input_dims`` (rows, cols)."""
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)

        def window_select(pos_embed):
            # Crop the top-left window of the grid when it is larger than needed.
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., :input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, :input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            if self.training:
                # Random scale in [sqrt(0.1), 1] and aspect ratio in [3/4, 4/3],
                # then a random translation keeping the crop inside the grid.
                min_scale = math.sqrt(0.1)
                scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
                aspect_min = math.log(3 / 4)
                aspect_max = -aspect_min
                aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)

                scale_x = scale * aspect
                scale_y = scale * (1 / aspect)
                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)

                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)

                lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
                lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])

                lin_xy = torch.stack([lin_x, lin_y], dim=-1)

                grid_xy = lin_xy * scale_xy + pos_xy

                # Convert to [-1, 1] range
                grid_xy.mul_(2).sub_(1)

                pos_embed = F.grid_sample(
                    pos_embed.float().expand(batch_size, -1, -1, -1),
                    grid=grid_xy,
                    mode='bilinear',
                    padding_mode='zeros',
                    align_corners=True,
                ).to(pos_embed.dtype)
            else:
                # Eval: resize to a square covering the larger dim, then crop.
                max_dim = max(input_dims)
                pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)

                pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)

        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed
class Im2Patches(nn.Module):
    """Split a (B, C, H, W) image into flattened, non-overlapping patches.

    Output is (B, num_patches, C * patch_size**2), patches ordered row-major.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = self.patch_size
        if p == 1:
            # Each pixel is its own patch: (B, C, H*W) -> (B, H*W, C).
            return x.flatten(2).permute(0, 2, 1)

        b, c = x.shape[:2]
        rows = x.shape[-2] // p
        cols = x.shape[-1] // p
        # (B, C, rows, p, cols, p) -> (B, rows, cols, C, p, p)
        # -> (B, rows*cols, C*p*p); equivalent to the einops pattern
        # 'b c (py yy) (px xx) -> b (py px) (c yy xx)'.
        tiles = x.reshape(b, c, rows, p, cols, p)
        tiles = tiles.permute(0, 2, 4, 1, 3, 5)
        return tiles.reshape(b, rows * cols, c * p * p)
class ViTPatchLinear(nn.Linear):
    """Linear projection of one flattened RGB patch (3 * p * p values) to embed_dim."""

    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
        in_features = 3 * patch_size ** 2  # RGB inputs assumed
        super().__init__(in_features, embed_dim, bias=bias, **factory)
        self.patch_size = patch_size
================================================
FILE: nit/models/nvidia_radio/radio/vitdet.py
================================================
from collections import defaultdict
from contextlib import contextmanager
from logging import getLogger
import math
import sys
from typing import List, Union, Iterable
import numpy as np
import torch
from torch import nn
from timm.models import VisionTransformer
from einops import rearrange
from .extra_models import DinoWrapper
DEFAULT_NUM_WINDOWED = 5
DEFAULT_NUM_GLOBAL = 4
class VitDetArgs:
    """Configuration bundle for the ViTDet windowed-attention hooks."""

    def __init__(self,
                 window_size: int,
                 num_summary_tokens: int,
                 num_windowed: int = None,
                 num_global: int = None,
    ):
        # Stored verbatim; ViTDetHook reads these fields directly.
        self.window_size = window_size
        self.num_summary_tokens = num_summary_tokens
        self.num_windowed = num_windowed
        self.num_global = num_global
def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
    """Attach ViTDet windowed-attention hooks to a supported model.

    Returns the installed ViTDetHook for VisionTransformer / DinoWrapper
    models; for anything else prints a warning to stderr and returns None.

    Fixes vs. original: the two supported branches shared duplicated hook
    wiring (now unified), and the warning used a pointless f-string prefix.
    """
    # Check order preserved: plain VisionTransformer first, then the wrapper.
    if isinstance(model, VisionTransformer):
        target = model
    elif isinstance(model, DinoWrapper):
        target = model.inner
    else:
        print('Warning: Unable to apply VitDet aug!', file=sys.stderr)
        return None

    # Prefer the CPE patch generator when present, else the plain embedder.
    patch_embed = getattr(target, 'patch_generator', target.patch_embed)
    return ViTDetHook(patch_embed, target.blocks, args)
class ViTDetHook:
    """Applies the ViTDet execution scheme to a ViT via forward (pre-)hooks.

    Most transformer blocks run windowed attention; periodically (and always
    for the final block) a block runs global attention over all patches.
    Patches are permuted once on entry so each window is contiguous, which
    makes the windowed<->global switch a pure reshape, and the permutation is
    undone on exit.
    """

    def __init__(self,
                 embedder: nn.Module,
                 blocks: nn.Sequential,
                 args: VitDetArgs,
    ):
        self.blocks = blocks
        self.num_summary_tokens = args.num_summary_tokens
        self.window_size = args.window_size

        # Per-forward-pass state filled in by the hooks below.
        self._input_resolution = None
        self._num_windows = None
        self._cls_patch = None
        # input resolution -> (patch permutation tensor, number of windows)
        self._order_cache = dict()

        embedder.register_forward_pre_hook(self._enter_model)

        # This will decide if we window-fy the patches
        # and enable vit-det for this iteration, and if so,
        # rearrange the patches for efficient mode switching
        blocks.register_forward_pre_hook(self._enter_blocks)

        is_global = True
        if args.num_windowed is not None:
            # Explicit pattern: num_windowed windowed blocks per global block.
            period = args.num_windowed + 1
        else:
            # Derive the period so ~num_global blocks end up global.
            num_global = args.num_global or DEFAULT_NUM_GLOBAL
            period = max(len(blocks) // num_global, 1)

        for i, layer in enumerate(blocks[:-1]):
            ctr = i % period
            if ctr == 0:
                layer.register_forward_pre_hook(self._to_windows)
                is_global = False
            elif ctr == period - 1:
                layer.register_forward_pre_hook(self._to_global)
                is_global = True

        # Always ensure the final layer is a global layer
        if not is_global:
            blocks[-1].register_forward_pre_hook(self._to_global)

        blocks.register_forward_hook(self._exit_model)

    def _enter_model(self, _, input: List[torch.Tensor]):
        # Record the image resolution for this forward pass.
        self._input_resolution = input[0].shape[-2:]

    def _enter_blocks(self, _, input: List[torch.Tensor]):
        # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)

        # Permute patches into window-contiguous order before the first block.
        patches = input[0]
        patches = self._rearrange_patches(patches)

        return (patches,) + input[1:]

    def _to_windows(self, _, input: List[torch.Tensor]):
        patches = input[0]

        if self.num_summary_tokens:
            # Hold summary (cls/register) tokens aside during windowing.
            self._cls_patch = patches[:, :self.num_summary_tokens]
            patches = patches[:, self.num_summary_tokens:]

        # Fold each window into the batch dim: one window per "sample".
        patches = rearrange(
            patches, 'b (p t) c -> (b p) t c',
            p=self._num_windows, t=self.window_size ** 2,
        )

        return (patches,) + input[1:]

    def _to_global(self, _, input: List[torch.Tensor]):
        patches = input[0]

        # Unfold windows back out of the batch dim.
        patches = rearrange(
            patches, '(b p) t c -> b (p t) c',
            p=self._num_windows, t=self.window_size ** 2,
            b=patches.shape[0] // self._num_windows,
        )

        if self.num_summary_tokens:
            # Re-attach the summary tokens saved in _to_windows.
            patches = torch.cat([
                self._cls_patch,
                patches,
            ], dim=1)

        return (patches,) + input[1:]

    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
        # Return patches to their original order
        patch_order = self._order_cache[self._input_resolution][0]
        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)

        # Inverse of the gather in _rearrange_patches.
        ret_patches = torch.empty_like(patches)
        ret_patches = torch.scatter(
            ret_patches,
            dim=1,
            index=patch_order,
            src=patches,
        )

        return ret_patches

    def _rearrange_patches(self, patches: torch.Tensor):
        # We rearrange the patches so that we can efficiently
        # switch between windowed and global mode by just
        # reshaping the tensor
        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
        if patch_order is None:
            # First forward at this resolution: derive and cache the permutation.
            num_feat_patches = patches.shape[1] - self.num_summary_tokens
            num_pixels = self._input_resolution[0] * self._input_resolution[1]

            # NOTE(review): patch size is inferred from the token count,
            # assuming patches tile the full input — confirm for padded inputs.
            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
            rows = self._input_resolution[-2] // patch_size
            cols = self._input_resolution[-1] // patch_size
            w_rows = rows // self.window_size
            w_cols = cols // self.window_size

            patch_order = torch.arange(0, num_feat_patches, device=patches.device)

            # Row-major grid order -> window-major order (each window's
            # patches contiguous).
            patch_order = rearrange(
                patch_order, '(wy py wx px) -> (wy wx py px)',
                wy=w_rows, wx=w_cols,
                py=self.window_size, px=self.window_size,
            )

            if self.num_summary_tokens:
                # Summary tokens keep their leading positions.
                patch_order = torch.cat([
                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
                    patch_order + self.num_summary_tokens,
                ])

            self._num_windows = w_rows * w_cols
            self._order_cache[self._input_resolution] = (
                patch_order,
                self._num_windows,
            )

        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
        patches = torch.gather(patches, dim=1, index=patch_order)
        return patches
================================================
FILE: nit/models/utils/convs.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from nit.models.efficientvit.models.nn.ops import ConvLayer
from nit.models.efficientvit.models.nn.act import build_act
from nit.models.efficientvit.models.utils import val2tuple
def create_conv_1(conv_type, in_channels, out_channels, norm, act_func, groups=1):
    '''
    Build a depthwise-style conv block from a spec string.

    conv_type: dwconv_3x3_1, dsconv_3x3_1, dgconv_3x3_1
    (format: <name>_<k>x<k>_<stride>). None, "", or an unrecognized
    name yields nn.Identity. `groups` is only used by dgconv.

    Fix vs. original: replaced the non-idiomatic `conv_type == None`
    comparison with a truthiness check.
    '''
    if not conv_type:  # None or ""
        return nn.Identity()

    parts = conv_type.split('_')
    name = parts[0]
    kernel_size = int(parts[1].split('x')[0])
    stride = int(parts[2])

    if name == 'dwconv':
        return DWConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func)
    elif name == 'dsconv':
        return DSConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func)
    elif name == 'dgconv':
        return DGConv(in_channels, out_channels, kernel_size, stride, groups, norm=norm, act_func=act_func)
    else:
        # Unknown conv types are a deliberate no-op.
        return nn.Identity()
def create_conv_2(conv_type, in_channels, out_channels, mid_channels):
    '''
    Build a bottleneck-style conv block from a spec string.

    conv_type: mbconv_3x3_1, fusedmbconv_3x3_1, glumbconv_3x3_1
    (format: <name>_<k>x<k>_<stride>). None, "", or an unrecognized
    name yields nn.Identity.

    Fix vs. original: replaced the non-idiomatic `conv_type == None`
    comparison with a truthiness check.
    '''
    if not conv_type:  # None or ""
        return nn.Identity()

    parts = conv_type.split('_')
    name = parts[0]
    kernel_size = int(parts[1].split('x')[0])
    stride = int(parts[2])

    if name == 'mbconv':
        return MBConv(in_channels, out_channels, kernel_size, stride, mid_channels)
    elif name == 'fusedmbconv':
        return FusedMBConv(in_channels, out_channels, kernel_size, stride, mid_channels)
    elif name == 'glumbconv':
        return GLUMBConv(in_channels, out_channels, kernel_size, stride, mid_channels)
    else:
        # Unknown conv types are a deliberate no-op.
        return nn.Identity()
class DWConv(nn.Module):
    """A single depthwise convolution (groups == in_channels)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=True,
        norm="bn2d",
        act_func="relu6",
    ):
        super().__init__()
        self.depth_conv = ConvLayer(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=in_channels,  # depthwise: one group per input channel
            norm=norm,
            act_func=act_func,
            use_bias=use_bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.depth_conv(x)
class DSConv(nn.Module):
    """Depthwise-separable conv: depthwise kxk followed by pointwise 1x1."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=(True, True),
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        # Expand per-layer options to (depthwise, pointwise) pairs.
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.depth_conv = ConvLayer(
            in_channels, in_channels, kernel_size, stride,
            groups=in_channels,
            norm=norm[0], act_func=act_func[0], use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels, out_channels, 1,
            norm=norm[1], act_func=act_func[1], use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.depth_conv(x))
class DGConv(nn.Module):
    """Depthwise kxk conv followed by a *grouped* pointwise 1x1 conv."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        groups=16,
        use_bias=(True, True),
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        # Expand per-layer options to (depthwise, pointwise) pairs.
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)

        self.depth_conv = ConvLayer(
            in_channels, in_channels, kernel_size, stride,
            groups=in_channels,
            norm=norm[0], act_func=act_func[0], use_bias=use_bias[0],
        )
        self.point_conv = ConvLayer(
            in_channels, out_channels, 1,
            groups=groups,  # grouped projection distinguishes DGConv from DSConv
            norm=norm[1], act_func=act_func[1], use_bias=use_bias[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.depth_conv(x))
class MBConv(nn.Module):
    """Inverted-bottleneck block: expand 1x1 -> depthwise kxk -> project 1x1."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=True,
        norm=("bn2d", "bn2d", "bn2d"),
        act_func=("relu6", "relu6", None),
    ):
        super().__init__()
        # Expand per-layer options to (expand, depthwise, project) triples.
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        if mid_channels is None:
            mid_channels = round(in_channels * expand_ratio)

        self.inverted_conv = ConvLayer(
            in_channels, mid_channels, 1, stride=1,
            norm=norm[0], act_func=act_func[0], use_bias=use_bias[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels, mid_channels, kernel_size, stride=stride,
            groups=mid_channels,
            norm=norm[1], act_func=act_func[1], use_bias=use_bias[1],
        )
        self.point_conv = ConvLayer(
            mid_channels, out_channels, 1,
            norm=norm[2], act_func=act_func[2], use_bias=use_bias[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.depth_conv(self.inverted_conv(x)))
class FusedMBConv(nn.Module):
    """Fused MBConv: a single spatial kxk conv (expand + spatial mixing fused)
    followed by a pointwise 1x1 projection."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=True,
        norm=("bn2d", "bn2d"),
        act_func=("relu6", None),
    ):
        super().__init__()
        # Expand per-layer options to (spatial, pointwise) pairs.
        use_bias = val2tuple(use_bias, 2)
        norm = val2tuple(norm, 2)
        act_func = val2tuple(act_func, 2)
        if mid_channels is None:
            mid_channels = round(in_channels * expand_ratio)

        self.spatial_conv = ConvLayer(
            in_channels, mid_channels, kernel_size, stride,
            groups=groups,
            use_bias=use_bias[0], norm=norm[0], act_func=act_func[0],
        )
        self.point_conv = ConvLayer(
            mid_channels, out_channels, 1,
            use_bias=use_bias[1], norm=norm[1], act_func=act_func[1],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.spatial_conv(x))
class GLUMBConv(nn.Module):
    """MBConv variant with a GLU gate: the expanded features are split into
    (value, gate) halves and multiplied before the pointwise projection."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=True,
        norm=(None, None, "ln2d"),
        act_func=("silu", "silu", None),
    ):
        super().__init__()
        # Expand per-layer options to (expand, depthwise, project) triples.
        use_bias = val2tuple(use_bias, 3)
        norm = val2tuple(norm, 3)
        act_func = val2tuple(act_func, 3)
        if mid_channels is None:
            mid_channels = round(in_channels * expand_ratio)

        # act_func[1] is applied to the gate half, not the depthwise conv.
        self.glu_act = build_act(act_func[1], inplace=False)
        self.inverted_conv = ConvLayer(
            in_channels, mid_channels * 2, 1,
            use_bias=use_bias[0], norm=norm[0], act_func=act_func[0],
        )
        self.depth_conv = ConvLayer(
            mid_channels * 2, mid_channels * 2, kernel_size, stride=stride,
            groups=mid_channels * 2,
            use_bias=use_bias[1], norm=norm[1], act_func=None,
        )
        self.point_conv = ConvLayer(
            mid_channels, out_channels, 1,
            use_bias=use_bias[2], norm=norm[2], act_func=act_func[2],
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.inverted_conv(x)
        h = self.depth_conv(h)
        value, gate = torch.chunk(h, 2, dim=1)
        return self.point_conv(value * self.glu_act(gate))
================================================
FILE: nit/models/utils/funcs.py
================================================
import torch
from torch import Tensor
from typing import List, Tuple
from itertools import chain
def modulate(x, shift, scale):
    """AdaLN-style modulation: per-sample scale and shift broadcast over dim 1."""
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (scale_b + 1) + shift_b
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the module's first parameter, else its first buffer.

    Implicitly returns None when the module has neither. The StopIteration
    fallback scans raw tensor attributes for torch.nn.DataParallel
    compatibility (PyTorch 1.5).
    """
    try:
        for param in parameter.parameters():
            return param.dtype
        for buf in parameter.buffers():
            return buf.dtype
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype
================================================
FILE: nit/models/utils/norms.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from functools import partial
import torch
import torch.nn as nn
import triton
import triton.language as tl
import torch.nn.functional as F
def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
"""
Creates the specified normalization layer based on the norm_type.
Args:
norm_type (str): The type of normalization layer to create.
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
Returns:
The created normalization layer.
Raises:
NotImplementedError: If an unknown norm_type is provided.
"""
if norm_type == None or norm_type == "":
return nn.Identity()
norm_type = norm_type.lower() # Normalize to lowercase
if norm_type == "layernorm":
return nn.LayerNorm(dim, eps=eps, bias=False)
elif norm_type == "np_layernorm":
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "np_layernorm_32":
return FP32_Layernorm(dim, eps=eps, elementwise_affine=False, bias=True)
elif norm_type == "layernorm_32":
return FP32_Layernorm(dim, eps=eps, bias=True)
elif norm_type == "rmsnorm":
return RMSNorm(dim, include_weight=True, eps=eps)
elif norm_type == "np_rmsnorm":
return RMSNorm(dim, include_weight=False, eps=1e-6)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=1/65536)
elif norm_type == "fused_rmsnorm_32":
return FusedRMSNorm32(dim, eps=1e-6)
elif norm_type == 'none':
return nn.Identity()
else:
return nn.Identity()
class FP32_Layernorm(nn.LayerNorm):
    """LayerNorm computed in float32, with the result cast back to the input
    dtype — avoids precision loss when activations are fp16/bf16.

    Fixes vs. original: the affine parameters were compared with `== None`
    (fragile for tensors) across three duplicated branches, and the
    weight-is-None / bias-is-set combination would have crashed. Both affine
    terms are now handled independently with `is None`.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        # Upcast whichever affine parameters exist; pass None through.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        return F.layer_norm(
            input=inputs.float(),
            normalized_shape=self.normalized_shape,
            weight=weight,
            bias=bias,
            eps=self.eps,
        ).to(origin_dtype)
class FusedRMSNorm(nn.Module):
    """RMSNorm backed by the fused Triton kernel defined later in this module."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize via the fused Triton RMSNorm kernel."""
        return self.fused_rms_norm_fn(x, self.weight, eps=self.eps)

    def reset_parameters(self):
        # Restore the learnable scale to its identity initialization.
        torch.nn.init.ones_(self.weight)  # type: ignore
class FusedRMSNorm32(nn.Module):
    """Fused Triton RMSNorm that computes in float32 and casts the result
    back to the input dtype."""

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.fused_rms_norm_fn = fused_rms_norm_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize in float32 via the fused kernel, then restore dtype."""
        in_dtype = x.dtype
        out = self.fused_rms_norm_fn(x.to(torch.float32), self.weight, eps=self.eps)
        return out.to(in_dtype)

    def reset_parameters(self):
        # Restore the learnable scale to its identity initialization.
        torch.nn.init.ones_(self.weight)  # type: ignore
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    Normalization is computed in float32 and cast back to the input dtype;
    an optional learnable per-dim scale is applied afterwards.

    Fix vs. original: `self.weight == None` replaced with `is None`
    (comparing a Parameter with `==` is fragile and unidiomatic).

    Args:
        dim (int): The dimension of the input tensor.
        include_weight (bool): Whether to include a learnable scale.
        eps (float, optional): Small value added to the denominator for
            numerical stability. Default 1e-6.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6, **block_kwargs):
        super().__init__()
        self.eps = eps
        if include_weight:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def _norm(self, x):
        """Scale each row by the reciprocal of its RMS."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Apply RMSNorm; output dtype matches the input dtype."""
        output = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return output
        return output * self.weight
# FusedRMSNorm in Triton
# Credit
# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,  # pointer to the (M, N) input
    stride_x,  # stride between rows of X
    Y,  # pointer to the (M, N) output
    stride_y,  # stride between rows of Y
    W,  # pointer to the (N,) weight vector
    Rstd,  # (M,) output: per-row 1/std, saved for the backward pass
    eps,
    M,  # num rows
    N,  # num cols
    block_N: tl.constexpr,  # power-of-two block covering a full row (>= N)
):
    # One program normalizes one row: y = x / sqrt(mean(x^2) + eps) * w.
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load input data and weights
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Compute mean and variance
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Store the reciprocal standard deviation
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    x_hat = x * rstd
    y = x_hat * w

    # Write output
    tl.store(Y + row * stride_y + cols, y, mask=mask)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,  # pointer to the saved (M, N) forward input
    stride_x,
    W,  # pointer to the (N,) weight vector
    DY,  # pointer to the (M, N) upstream gradient
    stride_dy,
    DX,  # pointer to the (M, N) input-gradient output
    stride_dx,
    Rstd,  # (M,) per-row 1/std saved by the forward kernel
    DW,  # (num_programs, N) partial weight grads, summed on the host
    eps,
    M,  # num rows
    N,  # num cols
    rows_per_program,  # how many rows each program processes
    block_N: tl.constexpr,  # power-of-two block covering a full row (>= N)
):
    # Each program handles a contiguous slice of rows, computing dx per row
    # and accumulating its own partial dw.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Load weights
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Accumulate gradients for weights
    dw = tl.zeros((block_N,), dtype=tl.float32)

    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Load input, output gradient, and reciprocal standard deviation
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Compute normalized input and gradients
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        # dx = (w*dy - x_hat * mean(x_hat * w*dy)) * rstd
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        # Store input gradient
        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Store weight gradients
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
class TritonFusedRMSNorm(torch.autograd.Function):
    """Autograd wrapper around the Triton RMSNorm forward/backward kernels.

    Inputs of arbitrary shape are flattened to (M, N) over the last dim;
    the per-row reciprocal std computed in forward is saved and reused in
    backward.
    """

    @staticmethod
    def forward(ctx, x, weight, eps):
        x_shape_start = x.shape

        # Flatten input
        x = x.view(-1, x.shape[-1])
        # Kernels assume unit stride along the feature dimension.
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()
        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # One block must cover a full row; 65536 bytes is the cap used here
        # (presumably a per-block resource limit -- verify against hardware).
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (M,)  # one program per row
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten input and output gradients
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        M, N = dy.shape
        dx = torch.empty_like(x)
        dw = torch.empty_like(weight)

        # One program per SM; each accumulates a partial weight gradient over
        # its slice of rows, and the partials are summed on the host below.
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        # Gradients for (x, weight, eps); eps is not differentiable.
        return dx, dw, None
# expose fusedRMSNorm as a function
def fused_rms_norm_fn(
    x,
    weight,
    eps=1e-6,
):
    """Functional entry point for the Triton fused RMSNorm autograd op."""
    return TritonFusedRMSNorm.apply(x, weight, eps)
================================================
FILE: nit/models/utils/pos_embeds/flash_attn_rotary.py
================================================
# Copyright (c) 2023, Tri Dao.
import math
from typing import Optional, Tuple, Union
import torch
from einops import rearrange, repeat
from flash_attn.ops.triton.rotary import apply_rotary
def rotate_half(x, interleaved=False):
    """Rotate feature pairs by 90 degrees: each (a, b) pair maps to (-b, a).

    Non-interleaved pairs the first half of the last dim with the second half
    (GPT-NeoX style); interleaved pairs even with odd dims (GPT-J style).
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Stack (-odd, even) pairs and flatten back: [-o0, e0, -o1, e1, ...]
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Pure-PyTorch rotary embedding reference.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first rotary_dim features are rotated; the rest pass through.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    if interleaved:
        # GPT-J layout: rotary pairs are adjacent (even, odd) dims.
        cos_b = torch.repeat_interleave(cos, 2, dim=-1).unsqueeze(-2)
        sin_b = torch.repeat_interleave(sin, 2, dim=-1).unsqueeze(-2)
        even, odd = x_rot[..., ::2], x_rot[..., 1::2]
        rotated = torch.stack((-odd, even), dim=-1).flatten(-2)
    else:
        # GPT-NeoX layout: first half pairs with second half.
        cos_b = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin_b = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        half1, half2 = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((-half2, half1), dim=-1)
    return torch.cat([x_rot * cos_b + rotated * sin_b, x_pass], dim=-1)
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper for the fused Triton `apply_rotary` kernel.

    Backward applies the same rotation with ``conjugate=True`` (the inverse
    rotation) to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            # Tensor offsets must go through save_for_backward; the None
            # marker tells backward which case applies.
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # Offsets were saved as a tensor in forward.
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        # One gradient slot per forward arg; only x is differentiable.
        return dx, None, None, None, None, None, None, None
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Differentiable rotary embedding (thin wrapper over ApplyRotaryEmb.apply,
    which does not accept keyword arguments).

    Arguments:
        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
            else (total_seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
# NOTE: packed/varlen variant of flash-attn's ApplyRotaryEmbQKV_ — qkv has no
# separate batch/seq dims; sequences are concatenated and delimited by cu_seqlens.
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """Apply rotary embedding *in-place* to Q and K of a packed qkv tensor.

    qkv: (total, 3, nheads, headdim); cos/sin: (seqlen, rotary_dim / 2).
    The backward pass applies the conjugate rotation to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        total, three, nheads, headdim = qkv.shape  # (total, 3, nheads, headdim)
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels.
            # qkv must be contiguous so that reshaping to combine the
            # (2, nheads) dimensions yields a *view*, letting the in-place
            # rotary write land in qkv itself.
            qk = qkv[:, :2].reshape(total, -1, headdim)
            apply_rotary(
                qk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                interleaved=interleaved,
                inplace=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, 0], qkv[:, 1]
            apply_rotary(
                q,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                interleaved=interleaved,
                inplace=True,
            )
            apply_rotary(
                k,
                cos_k,
                sin_k,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                interleaved=interleaved,
                inplace=True,
            )
        # BUGFIX: the original called ctx.save_for_backward unconditionally here
        # and then again below; the first call was redundant (the later call
        # overwrites it) and has been removed.
        if isinstance(seqlen_offsets, int):
            # ints can't go through save_for_backward; stash on ctx instead
            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            total, three, nheads, headdim = dqkv.shape
            # BUGFIX: the original sliced dqkv[:, :, :2] (the first two *heads*),
            # left over from the (batch, seqlen, 3, h, d) layout. For the packed
            # (total, 3, nheads, headdim) layout q/k live in dim 1, so mirror
            # the forward pass: fold (2, nheads) into a single head axis.
            dqk = dqkv[:, :2].reshape(total, -1, headdim)
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=ctx.max_seqlen,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, 0], dqkv[:, 1]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=ctx.max_seqlen,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets=seqlen_offsets,
                cu_seqlens=cu_seqlens,
                max_seqlen=ctx.max_seqlen,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        # One gradient slot per forward arg; only qkv gets a gradient.
        return dqkv, None, None, None, None, None, None, None, None
# NOTE(review): modified from upstream flash-attn for packed (varlen) qkv;
# ApplyRotaryEmbQKV_.forward in this file asserts a 4-D (total, 3, h, d) layout.
def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Arguments:
        qkv: (total, 3, nheads, headdim) — packed sequences (no separate
            batch/seq dims); ApplyRotaryEmbQKV_ asserts this 4-D layout.
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional (separate tables for K,
            e.g. for XPos)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) cumulative sequence lengths delimiting the packed sequences, or None
        max_seqlen: length of the longest packed sequence
    Return:
        qkv: same tensor, rotated in place
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen)
class ApplyRotaryEmbKV_(torch.autograd.Function):
    """Apply rotary embedding in-place to K of a (batch, seqlen, 2, nheads, headdim) kv tensor."""

    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        # rotate only the K slice; V passes through untouched
        apply_rotary(
            kv[:, :, 0], cos, sin,
            seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True,
        )
        if isinstance(seqlen_offsets, int):
            # ints can't go through save_for_backward; stash on ctx instead
            ctx.save_for_backward(cos, sin)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        offsets = ctx.seqlen_offsets
        if offsets is None:
            cos, sin, offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        # conjugate rotation of the K-gradient, in place
        apply_rotary(
            dkv[:, :, 0], cos, sin,
            seqlen_offsets=offsets, interleaved=ctx.interleaved,
            inplace=True, conjugate=True,
        )
        return dkv, None, None, None, None
# NOTE: the original also bound ``apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply``
# right before this def; that binding was dead code (immediately shadowed by the
# def below) and has been removed.
def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration
    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """
    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        dim: rotary dimension (number of channels that get rotated).
        base: frequency base of the inverse-frequency schedule.
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        scale_base: if not None, enables the XPos scaling path.
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # XPos per-channel decay factors; stays None when scale_base is unset,
        # which disables the XPos branches below.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)
        # cos/sin lookup tables, grown lazily by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
    def _compute_inv_freq(self, device=None):
        # Standard RoPE inverse frequencies: base**(-2i/dim) for i = 0..dim/2-1.
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: decay powers centered at seqlen // 2; queries use
                # scale**power and keys its reciprocal so q.k recovers the
                # relative-position decay.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.

        NOTE(review): the docstring describes the (batch, seqlen, ...) layout,
        but ApplyRotaryEmbQKV_ in this file was modified to assert a packed
        4-D qkv — confirm which layout the callers of this path use.
        """
        seqlen = qkv.shape[1]
        # When seqlen_offset is a tensor and max_seqlen is None, the cache is
        # NOT refreshed here — callers must have warmed it beforehand.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                # XPos: separate tables for q (scaled) and k (inverse-scaled)
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
================================================
FILE: nit/models/utils/pos_embeds/rope.py
================================================
# --------------------------------------------------------
# FiT: A Flexible Vision Transformer for Image Generation
#
# Based on the following repository
# https://github.com/lucidrains/rotary-embedding-torch
# https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope
# https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37
# --------------------------------------------------------
import math
from math import pi
from typing import Optional, Any, Union, Tuple
import torch
from torch import nn
from einops import rearrange, repeat
from functools import lru_cache
#################################################################################
# NTK Operations #
#################################################################################
def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Inverse of the RoPE dim formula: the (fractional) dimension index whose
    wavelength completes ``num_rotations`` rotations over the context length."""
    wavelength_ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(wavelength_ratio) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Dimension-index band [low, high] between the two rotation counts, clamped to [0, dim-1]."""
    lo_raw = find_correction_factor(low_rot, dim, base, max_position_embeddings)
    hi_raw = find_correction_factor(high_rot, dim, base, max_position_embeddings)
    # clamp values just in case
    return max(math.floor(lo_raw), 0), min(math.ceil(hi_raw), dim - 1)
def linear_ramp_mask(min, max, dim):
    """Length-``dim`` ramp: 0 before ``min``, linear up to 1 at ``max``, clamped to [0, 1].

    Parameter names intentionally keep the original (builtin-shadowing) names
    to preserve the keyword interface.
    """
    span = (max - min) if max != min else 0.001  # avoid division by zero
    ramp = (torch.arange(dim, dtype=torch.float32) - min) / span
    return torch.clamp(ramp, 0, 1)
def find_newbase_ntk(dim, base=10000, scale=1):
    """NTK-aware base change: stretch the rotary base so the lowest frequency
    spans a ``scale``-times longer context."""
    exponent = dim / (dim - 2)
    return base * scale ** exponent
def get_mscale(scale: torch.Tensor = torch.tensor(1.0)) -> torch.Tensor:
    """YaRN attention-magnitude correction: 1.0 where scale <= 1, else 0.1*ln(scale) + 1.

    BUGFIX: the signature used to read ``get_mscale(scale=torch.Tensor)`` —
    the *class* ``torch.Tensor`` was the default value (a typo for a type
    annotation), so calling with no argument crashed inside torch.where. The
    default is now a proper scalar tensor meaning "no scaling".
    """
    # Elementwise tensor version of the scalar rule:
    # if scale <= 1: return 1.0
    # return 0.1 * math.log(scale) + 1.0
    return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)
def get_proportion(L_test, L_train):
    """Entropy-based attention rescaling: 1.0 when the (doubled) test length does
    not exceed the train length, else sqrt(log(L_test*2) / log(L_train))."""
    doubled = L_test * 2
    ratio = torch.tensor(doubled / L_train)
    # note: both branches are evaluated (torch.where semantics), same as before
    attn_scale = torch.sqrt(torch.log(torch.tensor(doubled)) / torch.log(torch.tensor(L_train)))
    return torch.where(ratio <= 1., torch.tensor(1.0), attn_scale)
#################################################################################
# Rotate Q or K #
#################################################################################
def rotate_half(x):
    """Rotate every adjacent channel pair (a, b) of the last dim to (-b, a)."""
    pairs = x.unflatten(-1, (-1, 2))          # (..., d, 2)
    a, b = pairs.unbind(dim=-1)
    return torch.stack((-b, a), dim=-1).flatten(-2)
#################################################################################
# Core Vision RoPE #
#################################################################################
class VisionRotaryEmbedding(nn.Module):
    """2-D rotary position embedding (RoPE) for vision transformers.

    The per-head channel dim is split in half: the first half rotates with the
    row (h) coordinate and the second half with the column (w) coordinate.
    Several frequency-scaling schemes ('linear', 'ntk-aware', 'ntk-by-parts',
    'yarn', ...) support extrapolating beyond the trained grid size.
    """
    def __init__(
        self,
        head_dim: int, # embed dimension for each head
        custom_freqs: str = 'normal',
        theta: int = 10000,
        online_rope: bool = False,
        max_cached_len: int = 1024,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
    ):
        """
        head_dim: per-head embedding dim; half is used per spatial axis.
        custom_freqs: frequency-schedule name (case-insensitive).
        theta: RoPE base.
        online_rope: if True, skip precomputing frequency buffers/caches
            (use online_get_2d_rope_from_grid per call instead).
        max_cached_len: number of integer positions precomputed in the caches.
        max_pe_len_h / max_pe_len_w: target (extrapolated) grid size.
        decouple: if True, scale h and w frequencies independently.
        ori_max_pe_len: grid size the model was originally trained with.
        """
        super().__init__()
        dim = head_dim // 2  # channels per spatial axis
        assert dim % 2 == 0 # actually, this is important: each axis rotates channel *pairs*
        self.dim = dim
        self.custom_freqs = custom_freqs.lower()
        self.theta = theta
        self.decouple = decouple
        self.ori_max_pe_len = ori_max_pe_len
        self.custom_freqs = custom_freqs.lower()  # NOTE(review): duplicate of the assignment above; harmless
        if not online_rope:
            if self.custom_freqs in ['normal', 'scale1', 'scale2']:
                # vanilla RoPE inverse frequencies for both axes
                freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
                freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
            else:
                if decouple:
                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)
                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)
                else:
                    # share one scale derived from the larger target axis
                    max_pe_len = max(max_pe_len_h, max_pe_len_w)
                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)
            self.register_buffer('freqs_h', freqs_h, persistent=False)
            self.register_buffer('freqs_w', freqs_w, persistent=False)
            # NOTE(review): mscale / proportion1 / proportion2 only exist when all
            # three lengths are given; the 'yarn' / 'scale*' / 'ntk-aware-pro*'
            # branches below raise AttributeError otherwise.
            if max_pe_len_h != None and max_pe_len_w != None and ori_max_pe_len != None:
                attn_factor = 1.0
                scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale
                self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation
                self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)
                self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)
            # Precompute angle tables for integer positions [0, max_cached_len);
            # each frequency is duplicated ('(n r)', r=2) to cover its channel pair.
            freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)
            freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)
            self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False)
            freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)
            freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2)
            self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False)
    def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
        """Inverse-frequency vector for one axis under the selected scaling scheme.

        max_pe_len: int or tensor of target lengths; scaling only engages when it
        exceeds ori_max_pe_len (the scale is clamped to >= 1).
        """
        # scaling operations for extrapolation
        assert isinstance(ori_max_pe_len, int)
        # scale = max_pe_len / ori_max_pe_len
        if not isinstance(max_pe_len, torch.Tensor):
            max_pe_len = torch.tensor(max_pe_len)
        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale
        if self.custom_freqs == 'linear': # equal to position interpolation
            freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim))
        elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2':
            freqs = 1. / torch.pow(
                find_newbase_ntk(dim, theta, scale).view(-1, 1),
                (torch.arange(0, dim, 2).to(scale).float() / dim)
            ).squeeze()
        elif self.custom_freqs == 'ntk-by-parts':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_0 = 1.25
            beta_1 = 0.75
            gamma_0 = 16
            gamma_1 = 2
            ntk_factor = 1
            extrapolation_factor = 1
            #Three RoPE extrapolation/interpolation methods
            freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
            freqs_ntk = 1. / torch.pow(
                find_newbase_ntk(dim, theta, scale).view(-1, 1),
                (torch.arange(0, dim, 2).to(scale).float() / dim)
            ).squeeze()
            #Combine NTK and Linear
            low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor
            freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask
            #Combine Extrapolation and NTK and Linear
            low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor
            freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask
        elif self.custom_freqs == 'yarn':
            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
            #Do not change unless there is a good reason for doing so!
            beta_fast = 32
            beta_slow = 1
            extrapolation_factor = 1
            freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))
            freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))
            low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)
            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
            freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask
        else:
            raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')
        return freqs
    def online_get_2d_rope_from_grid(self, grid, size):
        '''
        Compute cos/sin tables on the fly (used with online_rope=True).
        grid: (B, 2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        size: (B, 1, 2), h goes first and w goes last
        '''
        # NOTE(review): squeeze() with no dim also drops the batch axis when
        # B == 1, which would break the size[:, 0] indexing below — confirm callers.
        size = size.squeeze() # (B, 1, 2) -> (B, 2)
        if self.decouple:
            size_h = size[:, 0]
            size_w = size[:, 1]
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)
        else:
            size_max = torch.max(size[:, 0], size[:, 1])
            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)
        # outer product of positions and per-sample frequencies, then duplicate
        # each frequency across its channel pair
        freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
        freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale[:, None, None]
            freqs_sin = freqs.sin() * self.mscale[:, None, None]
        elif self.custom_freqs == 'ntk-aware-pro1':
            freqs_cos = freqs.cos() * self.proportion1[:, None, None]
            freqs_sin = freqs.sin() * self.proportion1[:, None, None]
        elif self.custom_freqs == 'ntk-aware-pro2':
            freqs_cos = freqs.cos() * self.proportion2[:, None, None]
            freqs_sin = freqs.sin() * self.proportion2[:, None, None]
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin
    # NOTE(review): lru_cache on a bound method keys on (self, grid); tensors
    # hash by object identity, so the cache only hits when the exact same grid
    # tensor object is reused, and it keeps those tensors (and self) alive
    # for the process lifetime (ruff B019).
    @lru_cache()
    def get_2d_rope_from_grid(self, grid):
        '''
        grid: (B, 2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        '''
        # NOTE(review): grid row 0 is paired with freqs_h here, while
        # online_get_2d_rope_from_grid pairs row 0 with freqs_w; the two agree
        # only while freqs_h == freqs_w (the non-decoupled case) — confirm.
        freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
        freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale
            freqs_sin = freqs.sin() * self.mscale
        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
            freqs_cos = freqs.cos() * self.proportion1
            freqs_sin = freqs.sin() * self.proportion1
        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
            freqs_cos = freqs.cos() * self.proportion2
            freqs_sin = freqs.sin() * self.proportion2
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin
    # NOTE(review): same lru_cache caveats as above (identity-keyed, B019).
    @lru_cache()
    def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
        '''
        Look up precomputed angle tables; requires an integer grid with values
        below max_cached_len.
        grid: (B, 2, N) or (2, N)
        N = H * W
        the first dimension represents width, and the second represents height
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
        '''
        if len(grid.shape) == 3: # (B, 2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]]
        elif len(grid.shape) == 2: # (2, N)
            freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]]
        # NOTE(review): any other rank leaves freqs_h/freqs_w unbound (NameError).
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale
            freqs_sin = freqs.sin() * self.mscale
        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:
            freqs_cos = freqs.cos() * self.proportion1
            freqs_sin = freqs.sin() * self.proportion1
        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:
            freqs_cos = freqs.cos() * self.proportion2
            freqs_sin = freqs.sin() * self.proportion2
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin
    @lru_cache()
    def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d rope formulation 2 !
        '''
        2.1-D variant: the time index (row 2) is added to both spatial indices
        before the table lookup.
        grid: (B, 3, N)
        N = H * W * T
        the first dimension represents width, the second height, the third time
        e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
              [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
              [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
        '''
        freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]+grid[:, 2]], self.freqs_h_cached[grid[:, 1]+grid[:, 2]]
        freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D)
        if self.custom_freqs == 'yarn':
            freqs_cos = freqs.cos() * self.mscale
            freqs_sin = freqs.sin() * self.mscale
        elif self.custom_freqs == 'ntk-aware-pro1':
            freqs_cos = freqs.cos() * self.proportion1
            freqs_sin = freqs.sin() * self.proportion1
        elif self.custom_freqs == 'ntk-aware-pro2':
            freqs_cos = freqs.cos() * self.proportion2
            freqs_sin = freqs.sin() * self.proportion2
        else:
            freqs_cos = freqs.cos()
            freqs_sin = freqs.sin()
        return freqs_cos, freqs_sin
    def forward(self, x, grid):
        '''
        x: (B, n_head, N, D)
        grid: (B, 2, N) — integer positions (must be < max_cached_len for the cache lookup)
        '''
        # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid)
        # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        # using cache to accelerate, this is the same with the above codes:
        freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)
        freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)
        # standard complex rotation: x * cos + rotate_half(x) * sin
        return x * freqs_cos + rotate_half(x) * freqs_sin
================================================
FILE: nit/models/utils/pos_embeds/sincos.py
================================================
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
import torch
import numpy as np
from einops import rearrange
import torch.nn.functional as F
def get_2d_sincos_pos_embed(embed_dim, h, w, frac_coord_size=None, scale_ratio=1.0, cls_token=False, extra_tokens=0):
    """
    Build a sin-cos positional embedding for an h x w grid.

    args:
        h / w: int of the grid height / width
        frac_coord_size:
            if frac_coord_size != None:
                fractional coordinates for positional embedding is used
            else:
                absolute coordinates for positional embedding is used
    return:
        pos_embed: (1, h*w, embed_dim) or (1, extra_tokens + h*w, embed_dim)
        (w/ or w/o cls tokens, which get all-zero embeddings)
    """
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w = torch.arange(w, dtype=torch.float32)
    grid = torch.meshgrid(grid_w, grid_h, indexing='xy')  # here w goes first
    grid = torch.stack(grid, dim=0)  # (2, h, w)
    # BUGFIX: flatten the spatial dims so the grid is (1, 2, h*w) as documented;
    # the original passed a 4-D (1, 2, h, w) tensor, which the downstream
    # einsum('BL,D->BLD') in get_1d_sincos_pos_embed_from_grid rejects.
    grid = grid.reshape(2, -1).unsqueeze(0)  # (1, 2, h*w)
    pos_embed = get_2d_sincos_pos_embed_from_grid(
        grid, embed_dim, frac_coord_size, scale_ratio
    ) # 1, L, D
    if cls_token and extra_tokens > 0:
        pos_embed = torch.cat([torch.zeros((1, extra_tokens, embed_dim)), pos_embed], dim=1)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):
    '''
    grid: (B, 2, N)
    N = H * W
    the first dimension represents width, and the second represents height
    e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
          [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
    frac_coord_size:
        if frac_coord_size != None:
            fractional coordinates for positional embedding is used
        else:
            absolute coordinates for positional embedding is used
    '''
    assert embed_dim % 2 == 0
    grid = grid.float()
    if frac_coord_size is not None:
        assert isinstance(frac_coord_size, (int, float))
        # normalize each axis to [0, frac_coord_size]
        grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size
        grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size
    else:
        grid_w = grid[:, 0] * scale_ratio
        grid_h = grid[:, 1] * scale_ratio
    # half of the channels encode h, the other half encode w
    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, half)  # (B, N, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, half)  # (B, N, D/2)
    return torch.cat([emb_h, emb_w], dim=-1)  # (B, N, D)
def get_1d_sincos_pos_embed_from_grid(pos, embed_dim):
    """
    embed_dim: output dimension for each position
    pos: a batch of positions to be encoded, shape (B, N)
    out: (B, N, embed_dim) — sin features in the first half, cos in the second
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # geometric frequency schedule 10000^(-2i/embed_dim), computed in fp64
    exponents = torch.arange(half, dtype=torch.float64) / float(half)
    omega = 1. / 10000 ** exponents  # (D/2,)
    # outer product of positions and frequencies
    angles = torch.einsum('BL,D->BLD', pos, omega.to(pos))  # (B, N, D/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (B, N, D)
def get_3d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0, time_dim=0):
    '''
    grid: (B, 3, N)
    rows are (w, h, t) coordinates, e.g.
          [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
          [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]
          [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]
    frac_coord_size:
        if frac_coord_size != None:
            fractional coordinates for positional embedding is used
        else:
            absolute coordinates for positional embedding is used
    time_dim: channels reserved for time; 0 means split evenly across t/h/w
    '''
    if time_dim == 0:
        # split the channels evenly across the three axes
        assert embed_dim % 3 == 0
        time_dim = dim = embed_dim // 3
    else:
        # give time its own budget; split the rest between h and w
        assert (embed_dim - time_dim) % 2 == 0
        dim = (embed_dim - time_dim) // 2
    grid = grid.float()
    coords = [grid[:, 0], grid[:, 1], grid[:, 2]]  # w, h, t
    if frac_coord_size is not None:
        assert isinstance(frac_coord_size, (int, float))
        coords = [c / torch.max(c) * frac_coord_size for c in coords]
    else:
        coords = [c * scale_ratio for c in coords]
    grid_w, grid_h, grid_t = coords
    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, dim)
    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, dim)
    emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, time_dim)
    return torch.cat([emb_t, emb_h, emb_w], dim=-1)  # (B, N, D)
def get_21d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):
    '''
    "2+1-D" sinusoidal embedding: the spatial (h, w) halves are concatenated
    and a full-width time embedding is added on top.

    grid: (B, 3, N) stacked (w, h, t) coordinates, N = H * W.
    frac_coord_size:
        if frac_coord_size != None, coordinates are rescaled into
        [0, frac_coord_size]; otherwise absolute coordinates multiplied
        by scale_ratio are used.
    returns: (B, N, embed_dim)
    '''
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    grid = grid.float()
    if frac_coord_size is not None:
        assert isinstance(frac_coord_size, (int, float))
        grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size
        grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size
        grid_t = grid[:, 2] / torch.max(grid[:, 2]) * frac_coord_size
    else:
        grid_w = grid[:, 0] * scale_ratio
        grid_h = grid[:, 1] * scale_ratio
        grid_t = grid[:, 2] * scale_ratio
    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, half)
    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, half)
    emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, embed_dim)
    return torch.cat([emb_h, emb_w], dim=-1) + emb_t
def get_time_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):
    """Time-only sinusoidal embedding from grid[:, 0].

    frac_coord_size is accepted but unused, kept for API parity with the
    2-D/3-D variants.
    """
    timeline = grid.float()[:, 0] * scale_ratio
    return get_1d_sincos_pos_embed_from_grid(timeline, embed_dim)
#################################################################################
# interpolation #
#################################################################################
def interpolate_sincos_pos_embed(embed_dim, ori_h, ori_w, tgt_h, tgt_w):
    # Bilinearly resize a (ori_h x ori_w) 2-D sincos positional-embedding grid to
    # (tgt_h x tgt_w); returns a (1, tgt_h * tgt_w, embed_dim) float tensor.
    # NOTE(review): `src.inf.models.dit` does not exist in this repository tree
    # (the package here is `nit`) — this import path looks stale; verify before use.
    from src.inf.models.dit import get_2d_sincos_pos_embed
    pos_embed = get_2d_sincos_pos_embed(embed_dim, ori_h, ori_w)
    # numpy -> torch, add a leading batch dim of 1
    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
    # (1, H*W, D) -> (1, D, H, W) so F.interpolate can resample spatially
    pos_embed = rearrange(pos_embed, '1 (h w) d -> 1 d h w', h=ori_h, w=ori_w)
    pos_embed = F.interpolate(pos_embed, (tgt_h, tgt_w), mode='bilinear')
    pos_embed = rearrange(pos_embed, '1 d h w -> 1 (h w) d')
    return pos_embed
def interpolate_sincos_pos_index(embed_dim, ori_h, ori_w, tgt_h, tgt_w):
    # Build a (tgt_h x tgt_w) embedding by rescaling the *coordinate grid* into the
    # original (ori_h x ori_w) index range and evaluating the sincos formula directly
    # (index interpolation rather than bilinear interpolation of embedding values).
    # NOTE(review): stale-looking import path — see interpolate_sincos_pos_embed.
    from src.inf.models.dit import get_2d_sincos_pos_embed_from_grid
    grid_h = np.arange(tgt_h, dtype=np.float32) * ori_h / tgt_h
    grid_w = np.arange(tgt_w, dtype=np.float32) * ori_w / tgt_w
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, tgt_h, tgt_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
    return pos_embed
================================================
FILE: nit/schedulers/flow_matching/loss.py
================================================
import torch
import numpy as np
import torch.nn.functional as F
def mean_flat(x):
    """Mean over every dimension except the batch (first) dimension."""
    reduce_dims = list(range(1, x.dim()))
    return x.mean(dim=reduce_dims)
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.
    """
    # Docstring fixed: it previously said "mean" although this sums.
    return torch.sum(x, dim=list(range(1, len(x.size()))))
class FlowMatchingLoss:
    """Flow-matching training loss with EDM-style log-normal time sampling.

    Designed for packed variable-resolution batches: ``model_kwargs['hw_list']``
    is a (B, 2) tensor of per-sample token-grid sizes, and each sample's scalar
    timestep is repeated across its h*w tokens before interpolation.
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=None,
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
        P_mean=0.0,
        P_std=1.0,
        sigma_data=1.0,
        unit_variance=False,
    ):
        # prediction: model parameterization; only velocity ('v') is implemented.
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        # None-sentinel instead of a shared mutable default; attribute is [] as before.
        self.encoders = [] if encoders is None else encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias
        # Parameters of the EDM log-normal sigma distribution.
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.unit_variance = unit_variance

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) of the chosen path at time t."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * torch.pi / 2)
            sigma_t = torch.sin(t * torch.pi / 2)
            d_alpha_t = -torch.pi / 2 * torch.sin(t * torch.pi / 2)
            d_sigma_t = torch.pi / 2 * torch.cos(t * torch.pi / 2)
        elif self.path_type == 'triangle':
            alpha_t = torch.cos(t)
            sigma_t = torch.sin(t)
            d_alpha_t = -torch.sin(t)
            d_sigma_t = torch.cos(t)
        else:
            raise NotImplementedError()
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, batch_size, images, noises, model_kwargs=None, use_dir_loss=False, zs=[]):
        """Compute (denoising_loss, projection_loss) for one packed batch.

        Args:
            model: called as model(x_t, t, **model_kwargs); with return_zs=True
                it must return (velocity_prediction, intermediate_features).
            batch_size: number of samples packed into `images`.
            images / noises: packed clean latents and matching noise; their
                first dim is the total token count (sum of h*w over samples).
            model_kwargs: must contain 'hw_list', a (B, 2) tensor.
            use_dir_loss: additionally penalize the direction of the velocity.
            zs: optional target features for a representation-alignment loss.
        """
        if model_kwargs is None:
            model_kwargs = {}
        # sample timestep according to log-normal distribution of sigmas following EDM
        rnd_normal = torch.randn((batch_size))
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        if self.path_type == "linear":  # t in [0, 1]
            t = sigma / (1 + sigma)
        elif self.path_type == "cosine":  # t in [0, 1]
            t = 2 / np.pi * torch.atan(sigma)
        elif self.path_type == 'triangle':  # t in [0, pi/2]
            t = torch.atan(sigma / self.sigma_data)
        else:
            raise NotImplementedError
        t = t.to(device=images.device, dtype=images.dtype)
        time_input = t
        # Repeat each per-sample timestep over that sample's h*w tokens.
        hw_list = model_kwargs['hw_list']
        seqlens = hw_list[:, 0] * hw_list[:, 1]
        t = torch.cat([t[i].unsqueeze(0).repeat(seqlens[i], 1, 1, 1) for i in range(batch_size)], dim=0)
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(t)
        if self.unit_variance:
            # Scale the clean latents to unit variance (EDM-style preconditioning).
            model_input = alpha_t * images / self.sigma_data + sigma_t * noises
        else:
            model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()  # TODO: add x or eps prediction
        model_kwargs['return_zs'] = True
        model_output, zs_tilde = model(model_input, time_input, **model_kwargs)
        if self.unit_variance:
            # BUGFIX: the original multiplied the whole (output, zs) tuple returned
            # by model(...) by sigma_data before unpacking, which raises TypeError
            # for a float sigma_data; only the velocity output must be rescaled.
            model_output = self.sigma_data * model_output
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss = torch.nan_to_num(denoising_loss, nan=0, posinf=1e5, neginf=-1e5)
        loss = denoising_loss.mean()
        if use_dir_loss:
            directional_loss = mean_flat(1 - F.cosine_similarity(model_output, model_target, dim=1))
            directional_loss = torch.nan_to_num(directional_loss, nan=0, posinf=1e5, neginf=-1e5)
            loss += directional_loss.mean()
        # Representation-alignment (projection) loss against provided features.
        proj_loss = 0.
        if zs != [] and zs is not None:
            for z, z_tilde in zip(zs, zs_tilde):
                proj_loss += 1 - torch.cosine_similarity(z, z_tilde, dim=-1).mean()
            proj_loss = torch.nan_to_num(proj_loss, nan=0, posinf=1e5, neginf=-1e5)
        return loss, proj_loss
================================================
FILE: nit/schedulers/flow_matching/samplers_c2i.py
================================================
import torch
import numpy as np
def expand_t_like_x(t, x_cur, hw_list):
    """Broadcast per-sample times over a packed token sequence.

    Args:
        t: (B,) time vector, one entry per sample.
        x_cur: packed data tensor whose first dim is the total token count.
        hw_list: (B, 2) per-sample grids; sample i contributes h_i * w_i tokens.
    Returns:
        Tensor of shape (sum(h_i * w_i), 1, ..., 1) matching x_cur's rank.
    """
    trailing = [1] * (x_cur.dim() - 1)
    seqlens = hw_list[:, 0] * hw_list[:, 1]
    pieces = [
        t[i].unsqueeze(0).repeat(int(seqlens[i]), *trailing)
        for i in range(t.shape[0])
    ]
    return torch.cat(pieces, dim=0)
def get_score_from_velocity(vt, xt, t, hw_list, path_type="linear"):
    """Convert a velocity-model prediction into a score-function estimate.

    Args:
        vt: velocity model output, same packed shape as xt.
        xt: packed data point x_t.
        t: (B,) time tensor, expanded per token via expand_t_like_x.
        hw_list: (B, 2) per-sample grid sizes.
        path_type: 'linear' or 'cosine' interpolation path.
    Returns:
        Score estimate with the same shape as vt.
    """
    t = expand_t_like_x(t, xt, hw_list)
    if path_type == "linear":
        alpha_t = 1 - t
        d_alpha_t = -torch.ones_like(t, device=t.device)
        sigma_t = t
        d_sigma_t = torch.ones_like(t, device=t.device)
    elif path_type == "cosine":
        phase = t * np.pi / 2
        alpha_t = torch.cos(phase)
        sigma_t = torch.sin(phase)
        d_alpha_t = -np.pi / 2 * torch.sin(phase)
        d_sigma_t = np.pi / 2 * torch.cos(phase)
    else:
        raise NotImplementedError
    # score = (alpha/alpha' * v - x_t) / (sigma^2 - alpha/alpha' * sigma' * sigma)
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t ** 2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    return (reverse_alpha_ratio * vt - xt) / var
def compute_diffusion(t_cur):
    """Diffusion coefficient used by the SDE sampler; linear in t (g(t)^2 = 2t)."""
    return t_cur * 2
def euler_sampler(
    model,
    ag_model,
    latents,
    y,
    hw_list,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",  # not used, just for compatability
):
    """Deterministic Euler (optionally Heun 2nd-order) ODE sampler.

    Args:
        model: velocity model, called as model(x, t, y=..., hw_list=...).
        ag_model: optional auto-guidance model; when given, the unconditional
            branch is produced by ag_model instead of the null class.
        latents: packed initial noise; first dim is the total token count.
        y: (B,) class labels; label 1000 serves as the CFG null class.
        hw_list: (B, 2) per-sample latent grid sizes.
        num_steps: number of integration steps from t=1 down to t=0.
        heun: apply Heun's correction on all but the last step.
        cfg_scale / guidance_low / guidance_high: classifier-free-guidance
            scale and the t-interval where it is applied.
    Returns:
        The float64 sample at t=0.
    """
    # setup conditioning
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    if ag_model != None:
        auto_guidance = True
    else:
        auto_guidance = False

    _dtype = latents.dtype

    # integrate from t=1 (noise) down to t=0 (data) in float64 for stability
    t_steps = torch.linspace(1, 0, num_steps+1, dtype=torch.float64)
    x_next = latents.to(torch.float64)
    device = x_next.device

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            x_cur = x_next
            # Vanilla CFG duplicates the batch (cond + null class); with
            # auto-guidance the unconditional pass uses ag_model below instead.
            if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
                hw_list_cur = torch.cat([hw_list, hw_list], dim=0)
            else:
                model_input = x_cur
                y_cur = y
                hw_list_cur = hw_list
            kwargs = dict(y=y_cur, hw_list=hw_list_cur)
            time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur
            d_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
            ).to(torch.float64)
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                if auto_guidance:
                    # unconditional velocity from the auto-guidance ("bad") model
                    kwargs = dict(y=y_null, hw_list=hw_list_cur)
                    time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur
                    d_cur_uncond = ag_model(
                        model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                    ).to(torch.float64)
                    d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond)
                else:
                    d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                    d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            # explicit Euler step
            x_next = x_cur + (t_next - t_cur) * d_cur

            # Heun's correction: re-evaluate the velocity at t_next and average
            # the two slopes (skipped on the final step).
            if heun and (i < num_steps - 1):
                if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                    model_input = torch.cat([x_next] * 2)
                    y_cur = torch.cat([y, y_null], dim=0)
                    hw_list_cur = torch.cat([hw_list, hw_list], dim=0)
                else:
                    model_input = x_next
                    y_cur = y
                    hw_list_cur = hw_list
                kwargs = dict(y=y_cur, hw_list=hw_list_cur)
                time_input = torch.ones(y_cur.size(0)).to(
                    device=model_input.device, dtype=torch.float64
                ) * t_next
                d_prime = model(
                    model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                ).to(torch.float64)
                if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                    if auto_guidance:
                        kwargs = dict(y=y_null, hw_list=hw_list_cur)
                        time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_next
                        d_prime_uncond = ag_model(
                            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                        ).to(torch.float64)
                        d_prime = d_prime_uncond + cfg_scale * (d_prime - d_prime_uncond)
                    else:
                        d_prime_cond, d_prime_uncond = d_prime.chunk(2)
                        d_prime = d_prime_uncond + cfg_scale * (d_prime_cond - d_prime_uncond)
                x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next
def euler_maruyama_sampler(
    model,
    ag_model,
    latents,
    y,
    hw_list,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
):
    """Stochastic Euler-Maruyama SDE sampler for flow-matching models.

    Integrates the reverse SDE from t=1 down to t=0.04 with Gaussian noise
    injection at every step, then takes one final deterministic step to t=0
    and returns the posterior mean. Arguments mirror euler_sampler; passing
    ag_model switches classifier-free guidance to auto-guidance.
    """
    # setup conditioning
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    if ag_model != None:
        auto_guidance = True
    else:
        auto_guidance = False

    _dtype = latents.dtype

    # stochastic schedule stops at t=0.04; a final noiseless step reaches t=0
    t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
    x_next = latents.to(torch.float64)
    device = x_next.device

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            # duplicate batch for vanilla CFG (cond + null class)
            if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
                hw_list_cur = torch.cat([hw_list, hw_list], dim=0)
            else:
                model_input = x_cur
                y_cur = y
                hw_list_cur = hw_list
            kwargs = dict(y=y_cur, hw_list=hw_list_cur)
            time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)
            # Brownian increment scaled by sqrt(|dt|)
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            # compute drift
            v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
            ).to(torch.float64)
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                if auto_guidance:
                    # unconditional drift from the auto-guidance model
                    kwargs = dict(y=y_null, hw_list=hw_list_cur)
                    time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur
                    diffusion = compute_diffusion(t_cur)
                    eps_i = torch.randn_like(x_cur).to(device)
                    deps = eps_i * torch.sqrt(torch.abs(dt))
                    # compute drift
                    v_cur_uncond = ag_model(
                        model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                    ).to(torch.float64)
                    s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type)
                    d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond
                    d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond)
                else:
                    d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                    d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            # Euler-Maruyama update: drift term plus diffusion noise
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps

        # last step: deterministic (no injected noise); return the mean only
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
            hw_list_cur = torch.cat([hw_list, hw_list], dim=0)
        else:
            model_input = x_cur
            y_cur = y
            hw_list_cur = hw_list
        kwargs = dict(y=y_cur, hw_list=hw_list_cur)
        time_input = torch.ones(y_cur.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        # compute drift
        v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
        ).to(torch.float64)
        s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type)
        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            if auto_guidance:
                kwargs = dict(y=y_null, hw_list=hw_list_cur)
                time_input = torch.ones(y_null.size(0)).to(
                    device=device, dtype=torch.float64
                ) * t_cur
                # compute drift
                v_cur_uncond = ag_model(
                    model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                ).to(torch.float64)
                s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type)
                diffusion = compute_diffusion(t_cur)
                d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond
                d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond)
            else:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
        mean_x = x_cur + dt * d_cur

    return mean_x
================================================
FILE: nit/utils/__init__.py
================================================
from .misc_utils import *
from .train_utils import *
from .eval_utils import *
from .gpu_memory_monitor import *
================================================
FILE: nit/utils/deepspeed_zero_to_fp32.py
================================================
#!/usr/bin/env python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
#
# example: python zero_to_fp32.py . pytorch_model.bin
import argparse
import torch
import glob
import math
import os
import re
from collections import OrderedDict
from dataclasses import dataclass
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
# DeepSpeed data structures it has to be available in the current python environment.
from deepspeed.utils import logger
from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
@dataclass
class zero_model_state:
    """Per-rank model-state record parsed from a DeepSpeed ZeRO checkpoint shard.

    Holds buffers, parameter shape metadata, shared-parameter links, the
    DeepSpeed version string, and frozen-parameter shapes/fragments.
    """
    # NOTE: upstream annotated these fields with `dict()` — an empty-dict
    # *value*, not a type. The proper annotation is the `dict` type itself;
    # field names and construction are unchanged.
    buffers: dict
    param_shapes: dict
    shared_params: list
    ds_version: int
    frozen_param_shapes: dict
    frozen_param_fragments: dict
# Debug switch for the reconstruction helpers below; set to a truthy value
# (e.g. 1) to print shard shapes and numel accounting during conversion.
debug = 0

# load to cpu
# All checkpoint shards are materialized on CPU so the conversion never
# requires (or competes for) GPU memory.
device = torch.device('cpu')
def atoi(text):
    """Convert *text* to int when it is all digits; otherwise return it unchanged."""
    if text.isdigit():
        return int(text)
    return text
def natural_keys(text):
    '''
    Sort key that orders strings in human ("natural") order, so that e.g.
    "rank2" sorts before "rank10".

    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    # Split on digit runs; numeric chunks compare as ints, the rest as text.
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]
def get_model_state_file(checkpoint_dir, zero_stage):
    """Return the path of the single model-states file for the given ZeRO stage.

    Raises:
        FileNotFoundError: if the directory or the expected file is missing.
        ValueError: for an unrecognized zero_stage (previously this fell
            through and crashed with UnboundLocalError on `file`).
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

    # there should be only one file
    if zero_stage <= 2:
        file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
    elif zero_stage == 3:
        file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if not os.path.exists(file):
        raise FileNotFoundError(f"can't find model states file at '{file}'")

    return file
def get_checkpoint_files(checkpoint_dir, glob_pattern):
    """Return the files in *checkpoint_dir* matching *glob_pattern*, naturally sorted."""
    # XXX: need to test that this simple glob rule works for multi-node setup too
    pattern = os.path.join(checkpoint_dir, glob_pattern)
    ckpt_files = sorted(glob.glob(pattern), key=natural_keys)

    if not ckpt_files:
        raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")

    return ckpt_files
def get_optim_files(checkpoint_dir):
    """Return the naturally-sorted per-rank optimizer-state shard files."""
    pattern = "*_optim_states.pt"
    return get_checkpoint_files(checkpoint_dir, pattern)
def get_model_state_files(checkpoint_dir):
    """Return the naturally-sorted per-rank model-state shard files."""
    pattern = "*_model_states.pt"
    return get_checkpoint_files(checkpoint_dir, pattern)
def parse_model_states(files):
    """Parse per-rank *_model_states.pt files into zero_model_state records.

    Only buffers, parameter-shape metadata, shared-parameter links and
    frozen-parameter data are kept here; the trained weights themselves are
    reconstructed later from the optimizer shards.
    """
    zero_model_states = []
    for file in files:
        state_dict = torch.load(file, map_location=device)

        if BUFFER_NAMES not in state_dict:
            raise ValueError(f"{file} is not a model state checkpoint")
        buffer_names = state_dict[BUFFER_NAMES]
        if debug:
            print("Found buffers:", buffer_names)

        # recover just the buffers while restoring them to fp32 if they were saved in fp16
        buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
        param_shapes = state_dict[PARAM_SHAPES]

        # collect parameters that are included in param_shapes
        param_names = []
        for s in param_shapes:
            for name in s.keys():
                param_names.append(name)

        # update with frozen parameters
        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
        if frozen_param_shapes is not None:
            if debug:
                print(f"Found frozen_param_shapes: {frozen_param_shapes}")
            param_names += list(frozen_param_shapes.keys())

        # handle shared params
        shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]

        ds_version = state_dict.get(DS_VERSION, None)

        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)

        z_model_state = zero_model_state(buffers=buffers,
                                         param_shapes=param_shapes,
                                         shared_params=shared_params,
                                         ds_version=ds_version,
                                         frozen_param_shapes=frozen_param_shapes,
                                         frozen_param_fragments=frozen_param_fragments)
        zero_model_states.append(z_model_state)

    return zero_model_states
def parse_optim_states(files, ds_checkpoint_dir):
    """Load every per-rank *_optim_states.pt shard and extract the fp32 master weights.

    Returns (zero_stage, world_size, fp32_flat_groups): one flattened fp32
    tensor per rank for ZeRO-3, or a list of per-param-group tensors per rank
    for ZeRO-1/2.
    """
    total_files = len(files)
    state_dicts = []
    for f in files:
        state_dict = torch.load(f, map_location=device)
        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
        # and also handle the case where it was already removed by another helper script
        state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
        state_dicts.append(state_dict)

    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
        raise ValueError(f"{files[0]} is not a zero checkpoint")
    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]

    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
    # parameters can be different from data parallelism for non-expert parameters. So we can just
    # use the max of the partition_count to get the dp world_size.

    if type(world_size) is list:
        world_size = max(world_size)

    if world_size != total_files:
        raise ValueError(
            f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
            "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
        )

    # the groups are named differently in each stage
    if zero_stage <= 2:
        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
    elif zero_stage == 3:
        fp32_groups_key = FP32_FLAT_GROUPS
    else:
        raise ValueError(f"unknown zero stage {zero_stage}")

    if zero_stage <= 2:
        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
    elif zero_stage == 3:
        # if there is more than one param group, there will be multiple flattened tensors - one
        # flattened tensor per group - for simplicity merge them into a single tensor
        #
        # XXX: could make the script more memory efficient for when there are multiple groups - it
        # will require matching the sub-lists of param_shapes for each param group flattened tensor
        fp32_flat_groups = [
            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
        ]

    return zero_stage, world_size, fp32_flat_groups
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
    """
    Returns fp32 state_dict reconstructed from ds checkpoint

    Args:
        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
        - ``exclude_frozen_parameters``: skip merging frozen (non-trained) parameters
    """
    print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")

    optim_files = get_optim_files(ds_checkpoint_dir)
    zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
    print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")

    model_files = get_model_state_files(ds_checkpoint_dir)

    zero_model_states = parse_model_states(model_files)
    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

    # dispatch on the detected stage; parse_optim_states has already rejected
    # any stage other than 1/2/3, so no else branch is required here
    if zero_stage <= 2:
        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
    elif zero_stage == 3:
        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                                          exclude_frozen_parameters)
def _zero2_merge_frozen_params(state_dict, zero_model_states):
    """Copy frozen (non-trained) parameters into state_dict for a ZeRO-1/2 checkpoint.

    Frozen params are not partitioned under ZeRO-2, so rank 0's fragments
    already contain the full tensors.
    """
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    frozen_param_fragments = zero_model_states[0].frozen_param_fragments

    if debug:
        num_elem = sum(s.numel() for s in frozen_param_shapes.values())
        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # fragments already hold full tensors; copy them through as-is
        state_dict[name] = frozen_param_fragments[name]

        if debug:
            print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _has_callable(obj, fn):
attr = getattr(obj, fn, None)
return callable(attr)
def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Re-assemble trainable parameters from ZeRO-1/2 fp32 partitions.

    Each rank holds one contiguous fp32 slice per param group; concatenating
    the per-rank slices yields the full flat vector, which is then carved
    into individual parameter tensors by shape.
    """
    param_shapes = zero_model_states[0].param_shapes

    # Reconstruction protocol:
    #
    # XXX: document this

    if debug:
        for i in range(world_size):
            for j in range(len(fp32_flat_groups[0])):
                print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")

    # XXX: memory usage doubles here (zero2)
    num_param_groups = len(fp32_flat_groups[0])
    merged_single_partition_of_fp32_groups = []
    for i in range(num_param_groups):
        merged_partitions = [sd[i] for sd in fp32_flat_groups]
        full_single_fp32_vector = torch.cat(merged_partitions, 0)
        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
    avail_numel = sum(
        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])

    if debug:
        wanted_params = sum([len(shapes) for shapes in param_shapes])
        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
        # not asserting if there is a mismatch due to possible padding
        print(f"Have {avail_numel} numels to process.")
        print(f"Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    total_numel = 0
    total_params = 0
    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
        offset = 0
        avail_numel = full_single_fp32_vector.numel()
        for name, shape in shapes.items():

            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
            total_numel += unpartitioned_numel
            total_params += 1

            if debug:
                print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
            # carve the next unpartitioned_numel elements out of the flat vector
            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
            offset += unpartitioned_numel

        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
        # live optimizer object, so we are checking that the numbers are within the right range
        align_to = 2 * world_size

        def zero2_align(x):
            return align_to * math.ceil(x / align_to)

        if debug:
            print(f"original offset={offset}, avail_numel={avail_numel}")

        offset = zero2_align(offset)
        avail_numel = zero2_align(avail_numel)

        if debug:
            print(f"aligned offset={offset}, avail_numel={avail_numel}")

        # Sanity check
        if offset != avail_numel:
            raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the full fp32 state_dict from ZeRO-1/2 shards:
    buffers, then (optionally) frozen params, then trainable params,
    and finally re-link shared parameters.
    """
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero2_merge_frozen_params(state_dict, zero_model_states)

    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
    """Return (per-rank partition numel, padding numel) for a ZeRO-3 partitioned param.

    Each rank holds ceil(numel / world_size) elements; the last partition is
    padded so the total is an exact multiple of world_size.
    """
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    leftover = unpartitioned_numel % world_size
    padding_numel = world_size - leftover if leftover else 0
    return partitioned_numel, padding_numel
def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
    """Re-assemble frozen parameters from ZeRO-3 shards.

    Under ZeRO-3 even frozen params are partitioned; each full tensor is the
    concatenation of every rank's fragment, trimmed of trailing padding.
    """
    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
        return

    if debug:
        for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')

    frozen_param_shapes = zero_model_states[0].frozen_param_shapes
    wanted_params = len(frozen_param_shapes)
    wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
    avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
    print(f'Frozen params: Have {avail_numel} numels to process.')
    print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')

    total_params = 0
    total_numel = 0
    for name, shape in zero_model_states[0].frozen_param_shapes.items():
        total_params += 1
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel

        # concat every rank's fragment, then trim the final-rank padding
        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
        state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

    print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
    """Re-assemble trainable parameters from ZeRO-3 flat shards.

    Each param is partitioned contiguously across ranks at a shared offset
    into the flat group, so the full tensor is the concatenation of every
    rank's slice at that offset (minus trailing padding).
    """
    param_shapes = zero_model_states[0].param_shapes
    avail_numel = fp32_flat_groups[0].numel() * world_size
    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
    # param, re-consolidating each param, while dealing with padding if any

    # merge list of dicts, preserving order
    param_shapes = {k: v for d in param_shapes for k, v in d.items()}

    if debug:
        for i in range(world_size):
            print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")

    wanted_params = len(param_shapes)
    wanted_numel = sum(shape.numel() for shape in param_shapes.values())
    # not asserting if there is a mismatch due to possible padding
    avail_numel = fp32_flat_groups[0].numel() * world_size
    print(f"Trainable params: Have {avail_numel} numels to process.")
    print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")

    # params
    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
    # out-of-core computing solution
    offset = 0
    total_numel = 0
    total_params = 0
    for name, shape in param_shapes.items():
        unpartitioned_numel = shape.numel()
        total_numel += unpartitioned_numel
        total_params += 1

        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)

        if debug:
            print(
                f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
            )

        # XXX: memory usage doubles here
        state_dict[name] = torch.cat(
            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
            0).narrow(0, 0, unpartitioned_numel).view(shape)
        offset += partitioned_numel

    offset *= world_size

    # Sanity check
    if offset != avail_numel:
        raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")

    print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
                                               exclude_frozen_parameters):
    """Assemble the full fp32 state_dict from ZeRO-3 shards:
    buffers, then (optionally) frozen params, then trainable params,
    and finally re-link shared parameters.
    """
    state_dict = OrderedDict()

    # buffers
    buffers = zero_model_states[0].buffers
    state_dict.update(buffers)
    if debug:
        print(f"added {len(buffers)} buffers")

    if not exclude_frozen_parameters:
        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)

    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)

    # recover shared parameters
    for pair in zero_model_states[0].shared_params:
        if pair[1] in state_dict:
            state_dict[pair[0]] = state_dict[pair[1]]

    return state_dict
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
    """
    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
    ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
    via a model hub.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters

    Returns:
        - pytorch ``state_dict``

    Note: this approach may not work if your application doesn't have sufficient free CPU memory and
    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
    the checkpoint.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
        # do the training and checkpoint saving
        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
        model = model.cpu() # move to cpu
        model.load_state_dict(state_dict)
        # submit to model hub or save the model to share with others

    In this example the ``model`` will no longer be usable in the deepspeed context of the same
    application. i.e. you will need to re-initialize the deepspeed engine, since
    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.

    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
    """
    # resolve the tag from the 'latest' marker file when not explicitly given
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, 'latest')
        if os.path.isfile(latest_path):
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")

    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):
    """
    Consolidate a ZeRO 2 or 3 checkpoint into a single fp32 ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder (the one containing the
          tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, the tag
          is read from the ``latest`` file in the checkpoint folder, e.g., ``global_step14``
        - ``exclude_frozen_parameters``: exclude frozen parameters
    """
    consolidated = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
    print(f"Saving fp32 state dict to {output_file}")
    torch.save(consolidated, output_file)
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
    """
    Put ``model`` on CPU, consolidate the ZeRO 2/3 checkpoint into a single fp32
    ``state_dict``, and load it into the model (non-strict).

    Args:
        - ``model``: the model object to update
        - ``checkpoint_dir``: path to the desired checkpoint folder (the one containing the
          tag-folder, like ``global_step14``)
        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided, the
          tag is read from the ``latest`` file in the checkpoint folder, e.g., ``global_step14``

    Returns:
        - ``model``: the modified model (on CPU, fp32 weights loaded)

    Make sure you have plenty of CPU memory available before calling this; otherwise use the
    ``zero_to_fp32.py`` utility conveniently placed in the checkpoint folder.

    A typical usage might be ::

        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
        # submit to model hub or save the model to share with others

    Note: once this has run, the ``model`` is no longer usable in the deepspeed context of the
    same application — the deepspeed engine must be re-initialized.
    """
    logger.info(f"Extracting fp32 weights")
    fp32_state = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    logger.info(f"Overwriting model with fp32 weights")
    model = model.cpu()
    model.load_state_dict(fp32_state, strict=False)

    return model
if __name__ == "__main__":
    # CLI entry point: convert a ZeRO checkpoint into one fp32 state-dict file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str,
        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
    parser.add_argument(
        "output_file", type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
    parser.add_argument(
        "-t", "--tag", type=str, default=None,
        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
    parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
    cli_args = parser.parse_args()

    # `debug` is read as a module-level flag by the merge helpers above.
    debug = cli_args.debug

    convert_zero_checkpoint_to_fp32_state_dict(cli_args.checkpoint_dir,
                                               cli_args.output_file,
                                               tag=cli_args.tag,
                                               exclude_frozen_parameters=cli_args.exclude_frozen_parameters)
================================================
FILE: nit/utils/ema.py
================================================
import torch
from collections import OrderedDict
from copy import deepcopy
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """Step the EMA model towards the current model.

    Each EMA parameter is updated in place as
    ``decay * ema + (1 - decay) * current``.
    """
    # Unwrap DDP-style wrappers so parameter names match between the models.
    if hasattr(model, 'module'):
        model = model.module
    if hasattr(ema_model, 'module'):
        ema_model = ema_model.module

    ema_params = dict(ema_model.named_parameters())
    # TODO: Consider applying only to params that require_grad to avoid small
    # numerical changes of pos_embed.
    for name, current in model.named_parameters():
        ema_params[name].mul_(decay).add_(current.data, alpha=1 - decay)
================================================
FILE: nit/utils/eval_utils.py
================================================
from PIL import Image
import numpy as np
from tqdm import tqdm
import torch
import re
import os
from safetensors.torch import load_file
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Files are sorted by their numeric filename stem; the first ``num`` images
    are stacked into an (num, H, W, 3) uint8 array saved as ``arr_0``.
    """
    file_names = sorted(os.listdir(sample_dir), key=lambda f: int(f.split('.')[0]))
    print(len(file_names))
    assert len(file_names) >= num

    collected = []
    for idx in tqdm(range(num), desc="Building .npz file from samples"):
        img = Image.open(f"{sample_dir}/{file_names[idx]}")
        collected.append(np.asarray(img).astype(np.uint8))

    stacked = np.stack(collected)
    # All samples must be RGB and share one resolution.
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)

    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
def init_from_ckpt(
    model, checkpoint_dir, ignore_keys=None, verbose=False
) -> None:
    """Load weights from a checkpoint file into ``model`` (non-strict).

    Args:
        model: module that receives the weights.
        checkpoint_dir: path to a ``.safetensors`` file or any
            ``torch.load``-able checkpoint.
        ignore_keys: optional iterable of substrings; any state-dict key
            containing one of them is dropped before loading.
        verbose: print a summary of missing/unexpected keys.
    """
    if checkpoint_dir.endswith(".safetensors"):
        model_state_dict = load_file(checkpoint_dir, device='cpu')
    else:
        model_state_dict = torch.load(checkpoint_dir, map_location="cpu")
    model_new_ckpt = dict(model_state_dict)
    for k in list(model_new_ckpt.keys()):
        for ik in (ignore_keys or ()):
            if ik in k:
                print("Deleting key {} from state_dict.".format(k))
                del model_new_ckpt[k]
                # Stop after the first match: deleting the same key twice
                # (when several ignore patterns match it) raised KeyError
                # in the previous implementation.
                break
    missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False)
    if verbose:
        print(
            f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")
        print("")
def none_or_str(value):
    """argparse ``type``: map the literal string 'None' to None, pass others through."""
    return None if value == 'None' else value
def parse_sde_args(parser):
    """Register the SDE sampler's CLI options on ``parser``."""
    add = parser.add_argument_group("SDE arguments").add_argument
    add("--sde-sampling-method", type=str, default="Euler", choices=["Euler", "Heun"])
    add("--diffusion-form", type=str, default="sigma",
        choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],
        help="form of diffusion coefficient in the SDE")
    add("--diffusion-norm", type=float, default=1.0)
    add("--last-step", type=none_or_str, default="Mean",
        choices=[None, "Mean", "Tweedie", "Euler"],
        help="form of last step taken in the SDE")
    add("--last-step-size", type=float, default=0.04,
        help="size of the last step taken")
def parse_ode_args(parser):
    """Register the blackbox ODE solver's CLI options on ``parser``.

    Solver names (see https://github.com/rtqichen/torchdiffeq):
      - adaptive-step: dopri8, dopri5 (default), bosh3, adaptive_heun
      - fixed-step: euler, midpoint, rk4, explicit_adams, implicit_adams
    """
    add = parser.add_argument_group("ODE arguments").add_argument
    add("--ode-sampling-method", type=str, default="dopri5",
        help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq")
    add("--atol", type=float, default=1e-6, help="Absolute tolerance")
    add("--rtol", type=float, default=1e-3, help="Relative tolerance")
    add("--reverse", action="store_true")
    add("--likelihood", action="store_true")
================================================
FILE: nit/utils/freeze.py
================================================
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def freeze_model(model, trainable_modules=(), verbose=False):
    """Freeze every parameter of ``model`` except those whose name contains
    one of ``trainable_modules``.

    Args:
        model: module whose parameters are (un)frozen in place.
        trainable_modules: iterable of name substrings to keep trainable.
            (Changed from a mutable ``{}`` default to an immutable tuple.)
        verbose: log each per-parameter decision.
    """
    logger.info("Start freeze")
    for name, param in model.named_parameters():
        param.requires_grad = False
        if verbose:
            logger.info("freeze module: " + str(name))
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                if verbose:
                    logger.info("unfreeze module: " + str(name))
                break
    logger.info("End freeze")

    # Summarize how many parameters ended up trainable vs frozen.
    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    logger.info(f"Unfreeze Module Parameters: {num_trainable / 1e6} M")
    logger.info(f"Freeze Module Parameters: {num_frozen / 1e6} M")
    return
================================================
FILE: nit/utils/gpu_memory_monitor.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
from collections import namedtuple
from datetime import datetime
from typing import Any, Dict, Optional
import torch
# Lightweight record of peak GPU memory statistics, consumed by logging code.
GPUMemStats = namedtuple(
    "GPUMemStats",
    "max_active_gib max_active_pct max_reserved_gib max_reserved_pct "
    "num_alloc_retries num_ooms",
)
class GPUMemoryMonitor:
    """Tracks peak CUDA memory usage for a single device and reports it as
    :class:`GPUMemStats` records."""

    def __init__(self, logger, device: str = "cuda:0"):
        self.device = torch.device(device)  # device object
        self.device_name = torch.cuda.get_device_name(self.device)
        self.device_index = torch.cuda.current_device()
        self.device_capacity = torch.cuda.get_device_properties(self.device).total_memory
        self.device_capacity_gib = self._to_gib(self.device_capacity)
        self.logger = logger
        # Start from a clean slate so future peaks reflect only post-init usage.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def _to_gib(self, memory_in_bytes):
        # GiB (gibibyte, 2**30 bytes), not GB (10**9).
        return memory_in_bytes / (1024 * 1024 * 1024)

    def _to_pct(self, memory):
        # Fraction of total device memory, in percent.
        return 100 * memory / self.device_capacity

    def get_peak_stats(self):
        """Read CUDA's peak counters and return them as a GPUMemStats tuple,
        warning on allocation retries or OOM events."""
        cuda_info = torch.cuda.memory_stats(self.device)

        peak_active = cuda_info["active_bytes.all.peak"]
        peak_reserved = cuda_info["reserved_bytes.all.peak"]
        retries = cuda_info["num_alloc_retries"]
        ooms = cuda_info["num_ooms"]

        if retries > 0:
            self.logger.warning(f"{retries} CUDA memory allocation retries.")
        if ooms > 0:
            self.logger.warning(f"{ooms} CUDA OOM errors thrown.")

        return GPUMemStats(
            self._to_gib(peak_active),
            self._to_pct(peak_active),
            self._to_gib(peak_reserved),
            self._to_pct(peak_reserved),
            retries,
            ooms,
        )

    def reset_peak_stats(self):
        torch.cuda.reset_peak_memory_stats()
def build_gpu_memory_monitor(logger):
    """Create a GPUMemoryMonitor on the default CUDA device and log its capacity."""
    monitor = GPUMemoryMonitor(logger, "cuda")
    logger.info(
        f"GPU capacity: {monitor.device_name} ({monitor.device_index}) "
        f"with {monitor.device_capacity_gib:.2f}GiB memory"
    )
    return monitor
================================================
FILE: nit/utils/lr_scheduler.py
================================================
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for diffusion models."""
import math
from enum import Enum
from typing import Optional, Union
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
class SchedulerType(Enum):
    """Names of the supported LR schedules; values are the config-file strings."""
    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"
    # NOTE(review): "WARMDUP" looks like a typo for "WARMUP"; the member name is
    # referenced elsewhere in this module, so it is kept as-is.
    WARMDUP_STABLE_DECAY = "warmup_stable_decay"
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
    """
    Create a schedule that keeps the learning rate set in the optimizer constant.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with a constant multiplier of 1.
    """
    def constant_lambda(_):
        return 1

    return LambdaLR(optimizer, constant_lambda, last_epoch=last_epoch)
def get_constant_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, div_factor: int = 1e-4, last_epoch: int = -1
):
    """Linear warmup of the LR multiplier from ``div_factor`` to 1.0 over
    ``num_warmup_steps``, then a constant multiplier of 1.0."""
    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            return 1.0
        # Linear ramp: div_factor at step 0, reaching 1.0 at num_warmup_steps.
        warmup_frac = float(current_step) / float(max(1, num_warmup_steps))
        return (1 - div_factor) * warmup_frac + div_factor

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
    """
    Create a piecewise-constant LR schedule from a rule string.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            Comma-separated ``multiplier:step`` pairs plus a final multiplier,
            e.g. ``"1:10,0.1:20,0.01:30,0.005"``: multiplier 1 before step 10,
            0.1 before step 20, 0.01 before step 30, and 0.005 afterwards.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    *interval_rules, final_rule = step_rules.split(",")
    boundaries = {}
    for rule in interval_rules:
        multiplier_str, boundary_str = rule.split(":")
        boundaries[int(boundary_str)] = float(multiplier_str)
    tail_multiplier = float(final_rule)

    def make_rule(bounds, tail):
        ordered = sorted(bounds)

        def rule(steps: int) -> float:
            # Return the multiplier of the first boundary past `steps`.
            for boundary in ordered:
                if steps < boundary:
                    return bounds[boundary]
            return tail

        return rule

    return LambdaLR(optimizer, make_rule(boundaries, tail_multiplier), last_epoch=last_epoch)
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Linear warmup from 0 to the base LR over ``num_warmup_steps``, then linear
    decay down to 0 at ``num_training_steps``.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        remaining = float(num_training_steps - current_step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        # Clamp at 0 once training steps are exhausted.
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Linear warmup from 0 to the base LR, then cosine decay towards 0.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            Number of cosine periods over the decay phase (0.5 is a single
            half-cosine from the max value down to 0).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Linear warmup from 0 to the base LR, then cosine decay with ``num_cycles``
    hard restarts back to the max value.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        # Position within the current restart cycle, in [0, 1).
        phase = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * phase)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_polynomial_decay_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
):
    """
    Linear warmup from 0 to the base LR, then polynomial decay down to *lr_end*.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor (1.0 matches the fairseq/BERT implementation).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Raises:
        ValueError: if ``lr_end`` is not smaller than the optimizer's initial LR.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    lr_init = optimizer.defaults["lr"]
    if lr_init <= lr_end:
        raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        if current_step > num_training_steps:
            return lr_end / lr_init  # as LambdaLR multiplies by lr_init
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decayed = (lr_init - lr_end) * pct_remaining**power + lr_end
        return decayed / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_constant_schedule_with_warmup_and_decay(
    optimizer: Optimizer, num_warmup_steps: int, num_decay_steps: int, decay_T: int = 50000, div_factor: int = 1e-4, last_epoch: int = -1
):
    """Warmup-stable-decay schedule: linear warmup from ``div_factor`` to 1.0,
    flat at 1.0 until ``num_decay_steps``, then exponential halving of the
    multiplier every ``decay_T`` steps."""
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Linear ramp from div_factor (at step 0) to 1.0 (at warmup end).
            ramp = float(current_step) / float(max(1, num_warmup_steps))
            return (1 - div_factor) * ramp + div_factor
        if current_step > num_decay_steps:
            # Halve every decay_T steps past the decay point.
            return 0.5 ** ((current_step - num_decay_steps) / decay_T)
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch)
# Dispatch table: SchedulerType -> factory function (used by `get_scheduler`).
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
    SchedulerType.WARMDUP_STABLE_DECAY: get_constant_schedule_with_warmup_and_decay
}
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_decay_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    decay_T: Optional[int] = 50000,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            Step-rule string, only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            Number of warmup steps; required by every scheduler except
            `CONSTANT` and `PIECEWISE_CONSTANT`.
        num_decay_steps (`int`, *optional*):
            Decay start step, used by `WARMDUP_STABLE_DECAY`.
        num_training_steps (`int`, *optional*):
            Total training steps; required by the decaying schedulers.
        num_cycles (`int`, *optional*):
            Number of hard restarts, used by `COSINE_WITH_RESTARTS`.
        decay_T (`int`, *optional*, defaults to 50000):
            Halving period, used by `WARMDUP_STABLE_DECAY`.
        power (`float`, *optional*, defaults to 1.0):
            Power factor, used by `POLYNOMIAL`.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Raises:
        ValueError: if a required argument for the chosen scheduler is missing.
    """
    scheduler_type = SchedulerType(name)
    factory = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type]

    # Schedulers that need neither warmup nor training-step counts.
    if scheduler_type is SchedulerType.CONSTANT:
        return factory(optimizer, last_epoch=last_epoch)
    if scheduler_type is SchedulerType.PIECEWISE_CONSTANT:
        return factory(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All remaining schedulers require `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{scheduler_type} requires `num_warmup_steps`, please provide that argument.")

    if scheduler_type is SchedulerType.CONSTANT_WITH_WARMUP:
        return factory(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
    if scheduler_type is SchedulerType.WARMDUP_STABLE_DECAY:
        return factory(optimizer, num_warmup_steps=num_warmup_steps, num_decay_steps=num_decay_steps,
                       decay_T=decay_T, last_epoch=last_epoch)

    # The decaying schedulers additionally require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{scheduler_type} requires `num_training_steps`, please provide that argument.")

    extra_kwargs = {}
    if scheduler_type is SchedulerType.COSINE_WITH_RESTARTS:
        extra_kwargs["num_cycles"] = num_cycles
    elif scheduler_type is SchedulerType.POLYNOMIAL:
        extra_kwargs["power"] = power

    return factory(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        last_epoch=last_epoch,
        **extra_kwargs,
    )
================================================
FILE: nit/utils/misc_utils.py
================================================
import functools
import importlib
import os
import wandb
import fsspec
import numpy as np
import torch
from dataclasses import dataclass
from functools import partial
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
def get_dtype(str_dtype):
    """Map a short dtype string ('fp16'/'bf16') to a torch dtype; anything else is fp32."""
    return {'fp16': torch.float16, 'bf16': torch.bfloat16}.get(str_dtype, torch.float32)
def disabled_train(self, mode=True):
    """Drop-in replacement for ``Module.train`` that ignores ``mode``, pinning
    the module's current train/eval state."""
    return self
def get_string_from_tuple(s):
    """If ``s`` is the string form of a tuple literal (e.g. ``"('a', 'b')"``),
    return its first element; otherwise return ``s`` unchanged.

    Uses ``ast.literal_eval`` instead of ``eval`` so that arbitrary expressions
    embedded in the input cannot execute code; non-literal tuple strings now
    fall through and are returned as-is.
    """
    import ast
    try:
        # Check if the string starts and ends with parentheses
        if s[0] == "(" and s[-1] == ")":
            t = ast.literal_eval(s)
            if type(t) == tuple:
                return t[0]
    except Exception:
        # Empty/malformed input, non-literal expression, or empty tuple:
        # fall through and return the original string.
        pass
    return s
def is_power_of_two(n):
    """Return True iff ``n`` is a positive power of two.

    A power of two has exactly one bit set, so ``n & (n - 1)`` clears that bit
    and yields zero; any other positive number leaves other bits standing.
    Non-positive numbers are never powers of two.
    """
    return n > 0 and (n & (n - 1)) == 0
def autocast(f, enabled=True):
    """Wrap ``f`` so it executes under CUDA autocast with the current autocast
    GPU dtype and cache settings.

    Args:
        f: callable to wrap.
        enabled: whether mixed precision is active inside the wrapper.
    """
    @functools.wraps(f)  # preserve f's name/docstring for debugging and tooling
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)

    return do_autocast
def load_partial_from_config(config):
    """Return a `functools.partial` of the target class/function named in
    ``config["target"]``, pre-bound with ``config["params"]``."""
    target = get_obj_from_str(config["target"])
    return partial(target, **config.get("params", dict()))
def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` onto a white ``(w, h)`` image and return
    the batch as a tensor scaled to [-1, 1], shaped (b, 3, h, w).

    Args:
        wh: (width, height) tuple for the canvases.
        xc: list of captions (a caption may itself be a one-element list).
        size: font size for the rendered text.
    """
    rendered = list()
    # Wrap width in characters, scaled with the image width.
    chars_per_line = int(40 * (wh[0] / 256))
    for caption in xc:
        canvas = Image.new("RGB", wh, color="white")
        drawer = ImageDraw.Draw(canvas)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        text_seq = caption[0] if isinstance(caption, list) else caption
        lines = "\n".join(
            text_seq[start: start + chars_per_line]
            for start in range(0, len(text_seq), chars_per_line)
        )
        try:
            drawer.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 -> CHW float in [-1, 1].
        rendered.append(np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0)
    return torch.tensor(np.stack(rendered))
def partialclass(cls, *args, **kwargs):
    """Return a subclass of ``cls`` whose ``__init__`` has ``args``/``kwargs``
    pre-bound (the class-level analogue of ``functools.partial``)."""
    bound_init = functools.partialmethod(cls.__init__, *args, **kwargs)
    return type('NewCls', (cls,), {'__init__': bound_init})
def make_path_absolute(path):
    """Resolve ``path`` to an absolute local path when it lives on the local
    filesystem; non-file (remote) URLs are returned untouched."""
    filesystem, local_path = fsspec.core.url_to_fs(path)
    if filesystem.protocol != "file":
        return path
    return os.path.abspath(local_path)
def ismap(x):
    """True for 4-D tensors whose channel dimension exceeds 3 (feature maps)."""
    return isinstance(x, torch.Tensor) and len(x.shape) == 4 and x.shape[1] > 3
def isimage(x):
    """True for 4-D tensors with 1 or 3 channels (grayscale or RGB batches)."""
    return isinstance(x, torch.Tensor) and len(x.shape) == 4 and x.shape[1] in (3, 1)
def isheatmap(x):
    """True for 2-D tensors (single-channel spatial maps)."""
    return isinstance(x, torch.Tensor) and x.ndim == 2
def isneighbors(x):
    """True for 5-D tensors whose third dimension is 1 or 3 (neighbor stacks)."""
    return isinstance(x, torch.Tensor) and x.ndim == 5 and x.shape[2] in (3, 1)
def exists(x):
    """Return True when ``x`` is not None."""
    return not (x is None)
def expand_dims_like(x, y):
    """Append trailing singleton dimensions to ``x`` until it matches y's rank."""
    while x.dim() != y.dim():
        x = x[..., None]
    return x
def default(val, d):
    """Return ``val`` unless it is None; otherwise return ``d``, calling it
    first when it is a plain function (so defaults can be lazy)."""
    if val is not None:
        return val
    return d() if isfunction(d) else d
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Total number of parameters in ``model``; optionally print it in millions."""
    total_params = 0
    for p in model.parameters():
        total_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params
def instantiate_from_config(config):
    """Build the object described by ``config["target"]`` with
    ``config["params"]`` as keyword arguments.

    The sentinel strings ``"__is_first_stage__"`` and ``"__is_unconditional__"``
    map to None.
    """
    if "target" not in config:
        if config in ("__is_first_stage__", "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    target_cls = get_obj_from_str(config["target"])
    return target_cls(**config.get("params", dict()))
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Import and return the object named by a dotted path like ``"pkg.mod.Attr"``.

    Args:
        string: dotted path; everything before the last dot is the module.
        reload: re-import the module before resolving the attribute.
        invalidate_cache: clear importlib's finder caches first.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    return getattr(importlib.import_module(module_path, package=None), attr_name)
def append_zero(x):
    """Concatenate a single zero element onto the end of the 1-D tensor ``x``."""
    tail = x.new_zeros([1])
    return torch.cat([x, tail])
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has ``target_dims``
    dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    extra = target_dims - x.ndim
    if extra < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * extra]
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """Instantiate ``config.model`` and load weights from ``ckpt`` (non-strict).

    Supports ``.ckpt`` checkpoints (lightning-style, weights under
    ``"state_dict"``), ``.safetensors``, and plain ``.bin`` state dicts.
    Optionally freezes all parameters; the model is always put in eval mode.
    """
    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
    elif ckpt.endswith("safetensors"):
        sd = load_safetensors(ckpt)
    elif ckpt.endswith("bin"):
        sd = torch.load(ckpt, map_location="cpu")
    else:
        raise NotImplementedError

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.eval()
    return model
def format_number(num):
    """Format *num* in thousands, e.g. 12345 -> '12k' (always suffixed with 'k')."""
    thousands = float(num) / 1000.0
    return '{:.0f}{}'.format(thousands, 'k')
def get_num_params(model: torch.nn.ModuleList) -> int:
    """Total number of parameters in *model*."""
    return sum(p.numel() for p in model.parameters())
def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
    """Approximate training FLOPs per token: 6N plus the self-attention term."""
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // model_config.n_heads
    # Factor 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmuls in the forward and 4 in the backward (6)
    # 2. flash attention does 1 extra matmul recomputation in the backward, but
    #    recomputation is not counted when computing MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. by convention, sparsity from causal attention is not accounted for
    return 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len
def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:
    """Approximate training FLOPs per sequence for a bidirectional (encoder-only) transformer."""
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // model_config.n_heads
    # 1. each self-attention layer has 2 matmuls in the forward pass and 4 in the backward (6)
    # 2. each matmul performs 1 multiplication and 1 addition (*2)
    # 3. bidirectional attention covers every token pair, hence t^2 rather than t
    return 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len * seq_len
# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
def get_peak_flops(device_name: str) -> int:
    """Peak BF16 FLOPs for the named NVIDIA GPU; unknown devices fall back to A100 figures."""
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    if "H100" not in device_name:
        # for other GPU types, assume A100
        return 312e12
    # data from https://www.nvidia.com/en-us/data-center/h100/
    # NOTE: Specifications are one-half lower without sparsity.
    if "NVL" in device_name:
        return 1979e12
    if "PCIe" in device_name:
        return 756e12
    return 989e12  # SXM and other variants
@dataclass(frozen=True)
class Color:
    """ANSI SGR escape codes for coloured terminal output (foreground colours)."""
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"  # restore the terminal's default foreground colour
@dataclass(frozen=True)
class NoColor:
    """Drop-in replacement for Color whose codes are all empty strings (plain output)."""
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""
================================================
FILE: nit/utils/model_utils.py
================================================
import os
import torch
from transformers import T5EncoderModel, AutoModelForCausalLM, AutoTokenizer
# dc-ae
def dc_ae_encode(dc_ae, images):
    """Encode *images* with a DC-AE autoencoder and apply its latent scaling factor."""
    with torch.no_grad():
        scaled_latents = dc_ae.encode(images).latent * dc_ae.config.scaling_factor
    return scaled_latents
def dc_ae_decode(dc_ae, latents):
    """Undo DC-AE latent scaling and decode back to image space."""
    with torch.no_grad():
        z = latents / dc_ae.config.scaling_factor
        # Slicing decodes one sample at a time to bound peak memory usage.
        if dc_ae.use_slicing and z.size(0) > 1:
            pieces = [dc_ae._decode(chunk) for chunk in z.split(1)]
            decoded = torch.cat(pieces)
        else:
            decoded = dc_ae._decode(z)
    return decoded
# sd-vae
def sd_vae_encode(sd_vae, images):
    """Encode *images* with an SD-VAE and scale the (sampled) latents."""
    with torch.no_grad():
        encoded = sd_vae.encode(images)
        # diffusers ModelOutput subclasses dict; sample from its latent distribution.
        if isinstance(encoded, dict):
            encoded = encoded.latent_dist.sample()
        scaled = sd_vae.config.scaling_factor * encoded
    return scaled
def sd_vae_decode(sd_vae, latents):
    """Undo SD-VAE latent scaling and decode back to image space."""
    with torch.no_grad():
        unscaled = 1.0 / sd_vae.config.scaling_factor * latents
        result = sd_vae.decode(unscaled)
        # diffusers ModelOutput subclasses dict; unwrap its .sample tensor.
        if isinstance(result, dict):
            result = result.sample
    return result
# load text-encoder
def load_text_encoder(text_encoder_dir, device, weight_dtype):
    """Load a tokenizer plus frozen text encoder (Gemma decoder stack or T5 encoder) onto *device*."""
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)
    if 'gemma' in text_encoder_dir:
        tokenizer.padding_side = "right"
        # Only the decoder stack is kept for producing text features.
        causal_lm = AutoModelForCausalLM.from_pretrained(
            text_encoder_dir, attn_implementation="flash_attention_2", device_map='cpu', torch_dtype=weight_dtype
        )
        text_encoder = causal_lm.get_decoder()
    elif 't5' in text_encoder_dir:
        text_encoder = T5EncoderModel.from_pretrained(
            text_encoder_dir, attn_implementation="sdpa", device_map='cpu', torch_dtype=weight_dtype
        )
    else:
        raise NotImplementedError
    text_encoder.requires_grad_(False)
    text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype)
    return text_encoder, tokenizer
def encode_prompt(tokenizer, text_encoder, device, weight_dtype, captions, use_last_hidden_state, max_seq_length=256):
    """Tokenize *captions* and return (prompt_embeds, attention_mask) from *text_encoder*."""
    tokenized = tokenizer(
        captions,
        padding='max_length',
        max_length=max_seq_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized.input_ids.to(device)
    attention_mask = tokenized.attention_mask.to(device)
    with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype):
        outputs = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
    if use_last_hidden_state:
        embeds = outputs.last_hidden_state
    else:
        # Following the Imagen paper: use the penultimate layer's hidden states.
        embeds = outputs.hidden_states[-2]
    return embeds, attention_mask
def prepare_null_cap_feat_mask(text_encoder_type, device, weight_dtype, use_last_hidden_state, max_seq_length=256):
    """Encode the empty caption "" to obtain null-conditioning features and mask (for CFG)."""
    encoder, tok = load_text_encoder(
        text_encoder_dir=text_encoder_type, device=device, weight_dtype=weight_dtype
    )
    feats, mask = encode_prompt(
        tok, encoder, device, weight_dtype,
        "", use_last_hidden_state, max_seq_length=max_seq_length
    )
    return feats, mask
================================================
FILE: nit/utils/train_utils.py
================================================
import torch
from collections import OrderedDict
from copy import deepcopy
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def freeze_model(model, trainable_modules=(), verbose=False):
    """Freeze all parameters of *model* except those whose name contains one of
    *trainable_modules*.

    Args:
        model: module whose parameters are (un)frozen in place.
        trainable_modules: iterable of substrings; a parameter stays trainable
            when any entry occurs in its qualified name.  Defaults to an empty
            tuple (avoids the shared mutable-default pitfall of ``{}``).
        verbose: log every individual freeze/unfreeze decision.
    """
    logger.info("Start freeze")
    for name, param in model.named_parameters():
        param.requires_grad = False
        if verbose:
            logger.info("freeze module: " + str(name))
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                if verbose:
                    logger.info("unfreeze module: " + str(name))
                break
    logger.info("End freeze")
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    logger.info(f"Unfreeze Module Parameters: {n_trainable / 1e6} M")
    logger.info(f"Freeze Module Parameters: {n_frozen / 1e6} M")
    return
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``.
    DDP-wrapped models (with a ``.module`` attribute) are unwrapped first.
    """
    if hasattr(model, 'module'):
        model = model.module
    if hasattr(ema_model, 'module'):
        ema_model = ema_model.module
    target = OrderedDict(ema_model.named_parameters())
    source = OrderedDict(model.named_parameters())
    for name, src_param in source.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_param = target[name]
        ema_param.mul_(decay)
        ema_param.add_(src_param.data, alpha=1 - decay)
def log_validation(model):
    """Validation-logging hook; intentionally a no-op placeholder."""
    pass
================================================
FILE: nit/utils/util.py
================================================
import functools
import importlib
import os
import wandb
import fsspec
import numpy as np
import torch
from dataclasses import dataclass
from functools import partial
from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    # The requested mode is deliberately ignored; the module is returned unchanged.
    return self
def get_string_from_tuple(s):
    """If *s* looks like a stringified tuple ``"(…)"``, return its first element;
    otherwise return *s* unchanged.

    Uses ``ast.literal_eval`` instead of ``eval`` so arbitrary expressions in
    *s* are never executed; any parse failure falls back to returning *s*.
    """
    import ast
    try:
        # Check if the string starts and ends with parentheses
        if s[0] == "(" and s[-1] == ")":
            value = ast.literal_eval(s)
            if isinstance(value, tuple):
                return value[0]
    except (IndexError, TypeError, ValueError, SyntaxError, MemoryError, RecursionError):
        # Empty/non-indexable input or not a parseable literal: treat as plain string.
        pass
    return s
def is_power_of_two(n):
    """Return True if n is a power of 2, otherwise return False.

    A positive power of two has exactly one bit set, so ``n & (n - 1)``
    clears that bit and yields zero; non-positive values are never powers
    of two.
    """
    if n <= 0:
        return False
    return not (n & (n - 1))
def autocast(f, enabled=True):
    """Wrap *f* so every call runs under CUDA autocast.

    The wrapper reuses the current autocast GPU dtype and cache setting.
    ``functools.wraps`` preserves *f*'s name/docstring, which the original
    wrapper hid behind ``do_autocast``.
    """
    @functools.wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
            enabled=enabled,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled(),
        ):
            return f(*args, **kwargs)
    return do_autocast
def load_partial_from_config(config):
    """Return a ``functools.partial`` binding ``config['params']`` onto the resolved target callable."""
    target = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return partial(target, **kwargs)
def log_txt_as_img(wh, xc, size=10):
    """Render each caption in *xc* onto a white image and return the batch as a
    tensor of shape (B, 3, H, W) with values in [-1, 1].

    Args:
        wh: (width, height) of every rendered image.
        xc: list of captions; a list-valued entry uses its first element.
        size: font size for data/DejaVuSans.ttf (file must exist relative to cwd).
    """
    # wh a tuple of (width, height)
    # xc a list of captions to plot
    b = len(xc)
    txts = list()
    for bi in range(b):
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
        # Wrap at ~40 characters for a 256-px-wide image, scaled with the width.
        nc = int(40 * (wh[0] / 256))
        if isinstance(xc[bi], list):
            text_seq = xc[bi][0]
        else:
            text_seq = xc[bi]
        lines = "\n".join(
            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
        )
        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")
        # HWC uint8 -> CHW float in [-1, 1]
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts
def partialclass(cls, *args, **kwargs):
    """Return a subclass of *cls* whose ``__init__`` has *args*/*kwargs* pre-bound."""
    bound_init = functools.partialmethod(cls.__init__, *args, **kwargs)
    return type("NewCls", (cls,), {"__init__": bound_init})
def make_path_absolute(path):
    """Expand a local *path* to an absolute path; non-local (fsspec) URLs pass through unchanged."""
    fs, local_path = fsspec.core.url_to_fs(path)
    if fs.protocol != "file":
        return path
    return os.path.abspath(local_path)
def ismap(x):
    """True for a 4-D tensor whose channel dimension exceeds 3 (a feature map)."""
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] > 3
def isimage(x):
    """True for a 4-D tensor with 1 or 3 channels (an image batch)."""
    return isinstance(x, torch.Tensor) and x.ndim == 4 and x.shape[1] in (1, 3)
def isheatmap(x):
    """True for a bare 2-D tensor (H, W)."""
    if isinstance(x, torch.Tensor):
        return x.ndim == 2
    return False
def isneighbors(x):
    """True for a 5-D tensor whose third dimension has 1 or 3 channels."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 5 and x.shape[2] in (1, 3)
def exists(x):
    """Return True when *x* is not None (zero, empty containers etc. still count)."""
    return not (x is None)
def expand_dims_like(x, y):
    """Append trailing singleton dimensions to *x* until it matches *y*'s rank."""
    while x.dim() != y.dim():
        x = x[..., None]
    return x
def default(val, d):
    """Return *val* if it is not None, otherwise *d* (calling *d* when it is a function)."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_dims)
def count_params(model, verbose=False):
    """Return the total number of parameters in *model*, optionally printing it in millions."""
    n_params = 0
    for p in model.parameters():
        n_params += p.numel()
    if verbose:
        print(f"{model.__class__.__name__} has {n_params * 1.e-6:.2f} M params.")
    return n_params
def instantiate_from_config(config):
    """Build an object from a config mapping holding a dotted ``target`` path and optional ``params``."""
    if "target" not in config:
        # Sentinel configs deliberately instantiate nothing.
        if config == "__is_first_stage__" or config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return cls(**kwargs)
def get_obj_from_str(string, reload=False, invalidate_cache=True):
    """Resolve a dotted path like ``pkg.mod.Attr`` to the object it names."""
    module_path, attr_name = string.rsplit(".", 1)
    if invalidate_cache:
        importlib.invalidate_caches()
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module = importlib.import_module(module_path, package=None)
    return getattr(module, attr_name)
def append_zero(x):
    """Concatenate a single zero element onto the end of 1-D tensor *x*."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    n_extra = target_dims - x.ndim
    if n_extra < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    index = (...,) + (None,) * n_extra
    return x[index]
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
    """Instantiate ``config.model`` and load weights from *ckpt*.

    Supports ``.ckpt`` (Lightning-style checkpoints with a ``state_dict``
    key), ``.safetensors`` and ``.bin`` files.  Loading is non-strict;
    missing/unexpected keys are printed when *verbose*.  With *freeze* the
    model is switched to eval mode and gradients are disabled.
    """
    print(f"Loading model from {ckpt}")
    if ckpt.endswith("ckpt"):
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
    elif ckpt.endswith("safetensors"):
        sd = load_safetensors(ckpt)
    elif ckpt.endswith("bin"):
        sd = torch.load(ckpt, map_location="cpu")
    else:
        raise NotImplementedError(f"Unsupported checkpoint format: {ckpt}")
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    if freeze:
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
    return model
def format_number(num):
    """Format *num* in thousands, e.g. 12345 -> '12k' (always suffixed with 'k')."""
    thousands = float(num) / 1000.0
    return '{:.0f}{}'.format(thousands, 'k')
def get_num_params(model: torch.nn.ModuleList) -> int:
    """Total number of parameters in *model*."""
    return sum(p.numel() for p in model.parameters())
def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
    """Approximate training FLOPs per token: 6N plus the self-attention term."""
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // model_config.n_heads
    # Factor 12 for the self-attention part of the formula:
    # 1. each self-attention has 2 matmuls in the forward and 4 in the backward (6)
    # 2. flash attention does 1 extra matmul recomputation in the backward, but
    #    recomputation is not counted when computing MFU (+0)
    # 3. each matmul performs 1 multiplication and 1 addition (*2)
    # 4. by convention, sparsity from causal attention is not accounted for
    return 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len
def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:
    """Approximate training FLOPs per sequence for a bidirectional (encoder-only) transformer."""
    n_layers = model_config.n_layers
    n_heads = model_config.n_heads
    head_dim = model_config.dim // model_config.n_heads
    # 1. each self-attention layer has 2 matmuls in the forward pass and 4 in the backward (6)
    # 2. each matmul performs 1 multiplication and 1 addition (*2)
    # 3. bidirectional attention covers every token pair, hence t^2 rather than t
    return 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len * seq_len
# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU
def get_peak_flops(device_name: str) -> int:
    """Peak BF16 FLOPs for the named NVIDIA GPU; unknown devices fall back to A100 figures."""
    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    if "H100" not in device_name:
        # for other GPU types, assume A100
        return 312e12
    # data from https://www.nvidia.com/en-us/data-center/h100/
    # NOTE: Specifications are one-half lower without sparsity.
    if "NVL" in device_name:
        return 1979e12
    if "PCIe" in device_name:
        return 756e12
    return 989e12  # SXM and other variants
@dataclass(frozen=True)
class Color:
    """ANSI SGR escape codes for coloured terminal output (foreground colours)."""
    black = "\033[30m"
    red = "\033[31m"
    green = "\033[32m"
    yellow = "\033[33m"
    blue = "\033[34m"
    magenta = "\033[35m"
    cyan = "\033[36m"
    white = "\033[37m"
    reset = "\033[39m"  # restore the terminal's default foreground colour
@dataclass(frozen=True)
class NoColor:
    """Drop-in replacement for Color whose codes are all empty strings (plain output)."""
    black = ""
    red = ""
    green = ""
    yellow = ""
    blue = ""
    magenta = ""
    cyan = ""
    white = ""
    reset = ""
================================================
FILE: nit/utils/video_utils.py
================================================
import os
import cv2
import numpy as np
from PIL import Image
def save_video_as_mp4(video_array, fps, output_path):
    """Write a TCHW RGB frame array to *output_path* as an mp4v-encoded video."""
    # video_array: TCHW (RGB)
    n_frames = video_array.shape[0]
    height, width = video_array.shape[2], video_array.shape[3]
    writer = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
    )
    for idx in range(n_frames):
        rgb_frame = video_array[idx].transpose(1, 2, 0)  # CHW -> HWC
        bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)  # RGB->BGR
        writer.write(cv2.convertScaleAbs(bgr_frame))
    writer.release()
def save_video_as_png(video_array, output_path):
    """Dump each TCHW RGB frame of *video_array* as <output_path>/<index>.png."""
    os.makedirs(output_path, exist_ok=True)
    # video_array: TCHW (RGB)
    for frame_idx, frame in enumerate(video_array):
        hwc = np.transpose(frame, (1, 2, 0))  # CHW -> HWC
        Image.fromarray(hwc).save(os.path.join(output_path, f"{frame_idx:06d}.png"))
================================================
FILE: nit/utils/warp_pos_idx.py
================================================
import torch
import random
from typing import Optional, Union
def warp_pos_idx_from_grid(
    grid: torch.Tensor,
    shift: Optional[int] = 0,
    scale: Optional[str] = None,
    max_len: Optional[Union[int, float]] = None
):
    '''
    grid: the 2-D positional index to be warped (B, 2, D)
    shift: the max shift value for the positional indices
    scale: the scale scheme for warping positional indices
    max_len: the max scale length
    '''
    # Warp the row (0) and column (1) index planes independently, in place.
    for axis in (0, 1):
        grid[:, axis] = warp_pos_idx(grid[:, axis], shift, scale, max_len)
    return grid
def warp_pos_idx(
    pos_idx: torch.Tensor,
    shift: Optional[int] = 0,
    scale: Optional[str] = None,
    max_len: Optional[Union[int, float]] = None
):
    '''
    pos_idx: the 1-D positional index to be warped (B, D)
    shift: the max shift value for the positional indices; a single random
        offset drawn from [0, shift] is added to every index
    scale: the scale scheme for warping positional indices
        ('linear', 'sqrt', or 'sine'/'cosine'/'sin'/'cos')
    max_len: the max scale length (required when *scale* is given)
    '''
    if scale is not None:  # identity check instead of `scale != None`
        assert isinstance(scale, str) and isinstance(max_len, (int, float))
        scheme = scale.lower()
        if scheme == 'linear':
            pos_idx = max_len * (pos_idx / pos_idx.max())
        elif scheme == 'sqrt':
            pos_idx = max_len * torch.sqrt(pos_idx / max_len)
        elif scheme in ['sine', 'cosine', 'sin', 'cos']:
            pos_idx = max_len * torch.sin(pos_idx / max_len * (torch.pi / 2))
        else:
            # Message now lists the schemes this function actually implements
            # (the old one claimed "linear, cosine, beta").
            raise NotImplementedError('Only support linear, sqrt, sine scale scheme for warping')
    pos_idx = pos_idx + random.randint(0, shift)
    return pos_idx
================================================
FILE: projects/evaluate/adm_evaluator.py
================================================
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple
from PIL import Image
import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "checkpoints/classify_image_graph_def.pb"
FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"
def main():
    """CLI entry point: compare a sample batch against a reference batch and
    print Inception Score, FID, sFID, precision and recall."""
    parser = argparse.ArgumentParser()
    parser.add_argument("ref_batch", help="path to reference batch npz file")
    parser.add_argument("sample_batch", help="path to sample batch npz file")
    args = parser.parse_args()
    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))
    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()
    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
    print("Computing evaluations...")
    print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
    print("FID:", sample_stats.frechet_distance(ref_stats))
    print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    print("Precision:", prec)
    print("Recall:", recall)
class InvalidFIDException(Exception):
    """Exception type reserved for FID computation failures."""
    pass
class FIDStatistics:
    """Mean/covariance summary of a batch of Inception activations."""

    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        # mu: [D] feature mean; sigma: [D, D] feature covariance.
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1 = np.atleast_1d(self.mu)
        mu2 = np.atleast_1d(other.mu)
        sigma1 = np.atleast_2d(self.sigma)
        sigma2 = np.atleast_2d(other.sigma)
        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
        diff = mu1 - mu2

        # The covariance product might be almost singular.
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a slight imaginary component.
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
class Evaluator:
    """Computes Inception activations, Inception Score, FID statistics and
    precision/recall on top of a TensorFlow v1 session."""

    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        # batch_size: images per feature-extraction run;
        # softmax_batch_size: activations per softmax run.
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        # Run a dummy batch so TF builds/initializes the graph up front.
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        """Return (pool_3, spatial) activations for an .npz file or a directory of images."""
        if npz_path.endswith('.npz'):
            with open_npz_array(npz_path, "arr_0") as reader:
                return self.compute_activations(reader.read_batches(self.batch_size))
        else:
            preds = []
            spatial_preds = []
            files = os.listdir(npz_path)
            # NOTE(review): floor division drops the last partial batch of images.
            run_iter = int(len(files) / self.batch_size)
            for i in tqdm(range(run_iter)):
                samples = []
                for file in files[i*self.batch_size: (i+1)*self.batch_size]:
                    try:
                        sample_pil = Image.open(os.path.join(npz_path, file))
                        sample_np = np.asarray(sample_pil).astype(np.uint8)
                        samples.append(sample_np)
                    except:
                        print('wrong file', os.path.join(npz_path, file))
                samples = np.stack(samples)
                samples = samples.astype(np.float32)
                pred, spatial_pred = self.sess.run(
                    [self.pool_features, self.spatial_features], {self.image_input: samples}
                )
                preds.append(pred.reshape([pred.shape[0], -1]))
                spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
            return (
                np.concatenate(preds, axis=0),
                np.concatenate(spatial_preds, axis=0),
            )

    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.
        :param batches: a iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        """Load precomputed mu/sigma (and spatial mu_s/sigma_s) from the npz if
        present, otherwise compute them from *activations*."""
        if npz_path.endswith('.npz'):
            obj = np.load(npz_path)
            if "mu" in list(obj.keys()):
                return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                    obj["mu_s"], obj["sigma_s"]
                )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        # Mean and covariance over the activation batch.
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        """Standard Inception Score: mean over splits of exp(E[KL(p(y|x) || p(y))])."""
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        """Improved precision/recall between reference and sample activations."""
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))
class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.
    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.
        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade-off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        # Trigger TF graph construction with tiny dummy inputs.
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        """Radius (k-NN distance) of the hypersphere around each feature vector,
        one column per entry of nhood_sizes."""
        num_images = len(features)
        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]
            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]
                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)
            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )
        if self.clamp_to_percentile is not None:
            # Prune (zero out) hyperspheres larger than the given percentile radius.
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate if new feature vectors are at the manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]
            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies at the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
        # NOTE(review): "max_realisim_score" key keeps its historical misspelling;
        # renaming it would break downstream consumers of this dict.
        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            "max_realisim_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.
        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        # Status flags: whether each vector of one batch falls inside any
        # hypersphere of the other batch, accumulated block by block.
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )
class DistanceBlock:
    """
    Calculate pairwise distances between vectors.
    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            # Try float16 first; fall back to float32 whenever the
            # half-precision result contains non-finite values.
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        # Returns boolean masks: which rows of batch_1 fall inside any radii_2
        # hypersphere, and which rows of batch_2 fall inside any radii_1 one.
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )
def _batch_pairwise_distances(U, V):
    """
    Return the matrix of squared Euclidean distances between every row of U
    and every row of V.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # ||u||^2 as a column vector and ||v||^2 as a row vector, so the sum
        # below broadcasts to the full [len(U), len(V)] matrix.
        sq_norm_u = tf.reshape(tf.reduce_sum(tf.square(U), 1), [-1, 1])
        sq_norm_v = tf.reshape(tf.reduce_sum(tf.square(V), 1), [1, -1])
        # ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2; clamp at zero to guard
        # against tiny negative values from floating-point cancellation.
        cross = tf.matmul(U, V, False, True)
        return tf.maximum(sq_norm_u - 2 * cross + sq_norm_v, 0.0)
class NpzArrayReader(ABC):
    """Abstract reader that serves the rows of an npz array in batches."""

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return the next batch of at most ``batch_size`` rows, or None at EOF."""
        pass

    @abstractmethod
    def remaining(self) -> int:
        """Return the number of rows not yet consumed."""
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        """Iterate over all remaining rows in batches; ``len()`` of the result
        is the number of batches (a trailing partial batch counts)."""

        def _generate():
            while True:
                chunk = self.read_batch(batch_size)
                if chunk is None:
                    return
                yield chunk

        rem = self.remaining()
        # Ceiling division without floats.
        num_batches = rem // batch_size + (1 if rem % batch_size else 0)
        return BatchIterator(_generate, num_batches)
class BatchIterator:
    """Pairs a generator factory with a known length, so consumers can ask
    ``len()`` of a batch stream without materializing it."""

    def __init__(self, gen_fn, length):
        # gen_fn: zero-argument callable producing a fresh iterator per call.
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # Each iteration starts a brand-new generator.
        return self.gen_fn()
class StreamingNpzArrayReader(NpzArrayReader):
    """Reads rows of a C-ordered, fixed-itemsize ``.npy`` payload lazily from
    an already-open file handle, never loading the whole array."""

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f    # file-like object positioned at the array data
        self.shape = shape    # full array shape; shape[0] is the row count
        self.dtype = dtype
        self.idx = 0          # rows consumed so far

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        rows_left = self.shape[0] - self.idx
        if rows_left <= 0:
            return None
        bs = batch_size if batch_size < rows_left else rows_left
        self.idx += bs
        batch_shape = [bs, *self.shape[1:]]
        if self.dtype.itemsize == 0:
            # Zero-size items carry no payload; nothing to read from the file.
            return np.ndarray(batch_shape, dtype=self.dtype)
        n_items = bs * np.prod(self.shape[1:])
        raw = _read_bytes(self.arr_f, int(n_items * self.dtype.itemsize), "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape(batch_shape)

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)
class MemoryNpzArrayReader(NpzArrayReader):
    """Serves batches from an array held fully in memory."""

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0  # row cursor

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load array ``arr_name`` from the npz file at ``path`` and wrap it."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None
        batch = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return batch

    def remaining(self) -> int:
        # idx may overshoot the row count after the final partial batch,
        # hence the clamp at zero.
        return max(0, self.arr.shape[0] - self.idx)
@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    """
    Yield an NpzArrayReader for array ``arr_name`` inside the npz at ``path``.

    Streams straight from the open zip member when the .npy header is a known
    version and the data is C-ordered with non-object scalars; otherwise the
    whole array is loaded into memory as a fallback.
    """
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header = None
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        if header is None:
            # Unrecognized .npy format version: let numpy parse it in memory.
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            # Fortran order / object dtypes cannot be streamed row-by-row.
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
Read from file-like object until size bytes are read.
Raises ValueError if not EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
data = bytes()
while True:
# io files (default in python3) return None or raise on
# would-block, python2 file will truncate, probably nothing can be
# done about that. note that regular files can't be non-blocking
try:
r = fp.read(size - len(data))
data += r
if len(r) == 0 or len(data) == size:
break
except io.BlockingIOError:
pass
if len(data) != size:
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
return data
@contextmanager
def _open_npy_file(path: str, arr_name: str):
with open(path, "rb") as f:
with zipfile.ZipFile(f, "r") as zip_f:
if f"{arr_name}.npy" not in zip_f.namelist():
raise ValueError(f"missing {arr_name} in npz file")
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
yield arr_f
def _download_inception_model():
    """Download the InceptionV3 graph to INCEPTION_V3_PATH unless cached.

    Streams to a ``.tmp`` sibling first and renames only on completion, so an
    interrupted download never leaves a truncated file at the final path.
    """
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    tmp_path = INCEPTION_V3_PATH + ".tmp"
    with requests.get(INCEPTION_V3_URL, stream=True) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as out_f:
            for chunk in tqdm(resp.iter_content(chunk_size=8192)):
                out_f.write(chunk)
    os.rename(tmp_path, INCEPTION_V3_PATH)
def _create_feature_graph(input_batch):
    """Import the frozen InceptionV3 graph with ``input_batch`` mapped to its
    input, returning the pool_3 and spatial feature tensors.

    A random name prefix keeps node names unique, so the graph can be
    imported repeatedly into the same TF graph without collisions.
    """
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={f"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    # Relax the graph's fixed batch dimension so any batch size is accepted.
    _update_shapes(pool3)
    # Only the first 7 channels of the spatial features are kept.
    spatial = spatial[..., :7]
    return pool3, spatial
def _create_softmax_graph(input_batch):
    """Build a softmax classifier head over ``input_batch`` using the weight
    matrix taken from the InceptionV3 graph.

    Only the ``softmax/logits/MatMul`` op is imported — to grab its weight
    tensor — and fresh matmul/softmax ops are built on top of ``input_batch``.
    """
    _download_inception_model()
    # Random prefix avoids node-name collisions on repeated imports.
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
    )
    # Second input of the MatMul op is the classifier weight matrix.
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)
def _update_shapes(pool3):
    """Rewrite static shapes in the imported graph so a leading dim of 1
    becomes dynamic (None), allowing arbitrary batch sizes.

    https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63

    NOTE(review): this pokes TF's private ``_shape_val`` attribute, so it is
    tied to the TensorFlow version in use — confirm on TF upgrades.
    """
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape] TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        # Leading batch dimension of 1 -> dynamic.
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3
def _numpy_partition(arr, kth, **kwargs):
num_workers = min(cpu_count(), len(arr))
chunk_size = len(arr) // num_workers
extra = len(arr) % num_workers
start_idx = 0
batches = []
for i in range(num_workers):
size = chunk_size + (1 if i < extra else 0)
batches.append(arr[start_idx : start_idx + size])
start_idx += size
with ThreadPool(num_workers) as pool:
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
if __name__ == "__main__":
    # Script entry point; `main` is defined earlier in this file.
    main()
================================================
FILE: projects/preprocess/image_latent_c2i.py
================================================
import os
import torch
import argparse
import datetime
import time
import torchvision
import logging
import math
import shutil
import accelerate
import torch
import torch.utils.checkpoint
import diffusers
import numpy as np
import torch.nn.functional as F
import einops
import json
import os.path as osp
import functools
from PIL import Image
from torch.cuda import amp
from torch.utils.data import DataLoader, Dataset
from omegaconf import OmegaConf
from accelerate import Accelerator, skip_first_batches
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, AutoencoderDC
from nit.utils.misc_utils import instantiate_from_config
from torchvision import transforms
from torchvision.datasets.folder import DatasetFolder, default_loader
from torchvision.transforms.functional import hflip
from safetensors.torch import save_file
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from nit.utils.model_utils import dc_ae_encode, sd_vae_encode
logger = get_logger(__name__, log_level="INFO")
# For Omegaconf Tuple
def resolve_tuple(*args):
    """OmegaConf resolver body: pack the interpolated arguments into a tuple."""
    return tuple(args)

# Registered at import time so any config loaded below may use ${tuple:...}.
OmegaConf.register_new_resolver("tuple", resolve_tuple)
def parse_args():
    """Parse the command-line arguments for this preprocessing script.

    Returns the argparse namespace with ``config`` (YAML config path),
    ``project_dir`` (output root) and ``seed`` (optional reproducibility seed).
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # ----General Training Arguments----
    parser.add_argument("--config", type=str, default="", help="The config file for training.")
    parser.add_argument(
        "--project_dir",
        type=str,
        default="t2i_linear_attention",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    return parser.parse_args()
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Downscales in halves (BOX) while the short side is at least twice the
    target, resizes so the short side equals ``image_size`` (BICUBIC), then
    center-crops to a square of that size.
    """
    while min(*pil_image.size) >= 2 * image_size:
        half_size = tuple(x // 2 for x in pil_image.size)
        pil_image = pil_image.resize(half_size, resample=Image.Resampling.BOX)

    scale = image_size / min(*pil_image.size)
    scaled_size = tuple(round(x * scale) for x in pil_image.size)
    pil_image = pil_image.resize(scaled_size, resample=Image.Resampling.BICUBIC)

    arr = np.array(pil_image)
    top = (arr.shape[0] - image_size) // 2
    left = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[top: top + image_size, left: left + image_size])
# File extensions treated as images when scanning the dataset directory tree.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Unlike torchvision's stock ``ImageFolder``, ``__getitem__`` additionally
    returns the sample's file path (used downstream to name the saved latent).

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            # DatasetFolder treats the extension filter and is_valid_file as
            # mutually exclusive: pass extensions only when no predicate given.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, path) where target is class_index of the
            target class and path is the image's file path.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path
class ImagenetDataDictWrapper(Dataset):
    """Wraps a (sample, label, path)-tuple dataset so each item is served as a
    dict with keys ``jpg``, ``cls`` and ``path``."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, i):
        image, label, path = self.dataset[i]
        return {"jpg": image, "cls": label, "path": path}

    def __len__(self):
        return len(self.dataset)
# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
def get_train_sampler(global_batch_size, max_steps, resume_step):
sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long)
return sample_indices[resume_step * global_batch_size : ].tolist()
class ImagenetLoader():
    """Builds dataloaders over an ImageNet-style directory tree; images are
    center-cropped to ``data_config.dataset.resolution`` and normalized to
    [-1, 1].

    Fix: ``val_dataloader`` read ``self.shuffle``, which was never assigned
    anywhere and raised AttributeError; it is now initialized to False
    (deterministic order) in ``__init__``.
    """

    def __init__(self, data_config):
        super().__init__()
        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        # Used by val_dataloader(); False keeps validation order deterministic.
        self.shuffle = False
        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, data_config.dataset.resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, transform=transform))
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        """Number of samples in the training dataset."""
        return len(self.train_dataset)

    def train_dataloader(self, global_batch_size, max_steps, resume_step):
        """Dataloader driven by a precomputed flat index list so iteration can
        resume at ``resume_step`` global batches into the schedule."""
        sampler = get_train_sampler(
            global_batch_size, max_steps, resume_step
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True
        )

    def test_dataloader(self):
        # No test split is configured for this pipeline.
        return None

    def val_dataloader(self):
        # Validation reuses the training dataset (no held-out split here).
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False
        )
def main(args):
    """Encode every dataset image into VAE latents and save them to disk.

    Loads an SD-VAE (``AutoencoderKL``) or DC-AE (``AutoencoderDC``) chosen by
    the ``model.vae`` config string, iterates the dataset once, encodes each
    image and its horizontal flip, and writes one ``.safetensors`` file per
    image containing the stacked latents plus the class label. Accelerate
    state is checkpointed periodically so the pass can be resumed.
    """
    project_dir = args.project_dir
    config = OmegaConf.load(args.config)
    model_config = config.model
    data_config = config.data
    train_config = config.training

    config_dir = osp.join(project_dir, 'configs')
    checkpoint_dir = osp.join(project_dir, 'checkpoints')
    logging_dir = osp.join(project_dir, 'logs')
    sample_dir = osp.join(project_dir, 'samples')
    accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=train_config.gradient_accumulation_steps,
        mixed_precision=train_config.mixed_precision,
        log_with=train_config.tracker,
        project_config=accelerator_project_config,
        split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs(project_dir, exist_ok=True)
        os.makedirs(config_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(logging_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)
        OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml"))

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        diffusers.utils.logging.set_verbosity_info()
    else:
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
    if train_config.allow_tf32: # for A100
        torch.backends.cuda.matmul.allow_tf32 = True

    # Setup models: the encoder is selected by substring of the VAE id.
    weight_dtype = torch.float32
    if 'sd-vae' in model_config.vae:
        sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)
        sd_vae.eval()
        sd_vae.requires_grad_(False)
        encode_func = functools.partial(sd_vae_encode, sd_vae)
    elif 'dc-ae' in model_config.vae:
        dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)
        dc_ae.eval()
        dc_ae.requires_grad_(False)
        encode_func = functools.partial(dc_ae_encode, dc_ae)

    # Setup Dataloader
    total_batch_size = (
        data_config.dataloader.batch_size *
        accelerator.num_processes *
        train_config.gradient_accumulation_steps
    )
    global_steps = 0
    if train_config.resume_from_checkpoint:
        # normal read with safety check
        if train_config.resume_from_checkpoint != "latest":
            # NOTE(review): basename() drops the directory part, and the result
            # is later passed to accelerator.load_state as-is — verify this
            # resolves correctly for checkpoint paths outside the CWD.
            resume_from_path = os.path.basename(train_config.resume_from_checkpoint)
        else: # Get the most recent checkpoint
            dirs = os.listdir(checkpoint_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None
        if resume_from_path is None:
            logger.info(
                f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            train_config.resume_from_checkpoint = None
        else:
            global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps
            logger.info(f"Resuming from steps: {global_steps}")

    get_train_dataloader = ImagenetLoader(data_config)
    train_len = get_train_dataloader.train_len()
    # Single pass over the dataset: total steps = ceil(len / global batch).
    train_config.max_train_steps = math.ceil(train_len / total_batch_size)
    train_dataloader = get_train_dataloader.train_dataloader(
        global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps,
    )

    # Prepare Accelerate
    train_dataloader = accelerator.prepare(train_dataloader)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    logger.info("***** Running training *****")
    logger.info(f" Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}")
    logger.info(f" Dataset Length = {get_train_dataloader.train_len()}")
    logger.info(f" Instantaneous batch size per device = {data_config.dataloader.batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {train_config.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {train_config.max_train_steps}")

    # Potentially load in the weights and states from a previous save
    if train_config.resume_from_checkpoint and resume_from_path != None:
        accelerator.print(f"Resuming from checkpoint {resume_from_path}")
        accelerator.load_state(resume_from_path)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, train_config.max_train_steps),
        disable = not accelerator.is_local_main_process
    )
    progress_bar.set_description("Optim Steps")
    progress_bar.update(global_steps)

    # prepare patch size and max sequence length
    # make directory
    os.makedirs(data_config.dataset.target_dir, exist_ok=True)

    def save_data(z, y, p):
        """Write one image's latents and label to
        <target_dir>/<class_folder>/<image_name>.safetensors."""
        # p: 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG'
        # target_folder: 'target_dir/n01440764'
        target_folder = os.path.join(data_config.dataset.target_dir, p.split('/')[-2])
        f_name = p.split('/')[-1].split('.')[0]
        os.makedirs(target_folder, exist_ok=True)
        single_data = dict(latent=z.contiguous(), label=y)
        save_file(single_data, os.path.join(target_folder, f'{f_name}.safetensors'))

    for step, batch in enumerate(train_dataloader, start=global_steps):
        # Cast tensor entries to fp32; lists/strings (e.g. paths) pass through.
        for batch_key in batch.keys():
            if not isinstance(batch[batch_key], (list, str)):
                batch[batch_key] = batch[batch_key].to(dtype=weight_dtype)
        x = batch['jpg']
        y = batch['cls']
        p = batch['path']
        if 'sd-vae' in model_config.vae:
            # Encode the image and its horizontal flip; stacking on dim=1
            # gives each saved latent z[i] the shape (2, *latent_dims).
            z_ori = encode_func(x)
            z_hflip = encode_func(hflip(x))
            z = torch.stack([z_ori, z_hflip], dim=1)
        elif 'dc-ae' in model_config.vae:
            z_ori = encode_func(x)
            z_hflip = encode_func(hflip(x))
            z = torch.stack([z_ori, z_hflip], dim=1)
        for i in range(len(p)):
            save_data(z[i], y[i], p[i])

        # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_steps += 1
            if global_steps % train_config.checkpointing_steps == 0:
                if accelerator.is_main_process:
                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                    if train_config.checkpoints_total_limit is not None:
                        checkpoints = os.listdir(checkpoint_dir)
                        checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                        if len(checkpoints) >= train_config.checkpoints_total_limit:
                            num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1
                            removing_checkpoints = checkpoints[0:num_to_remove]
                            logger.info(
                                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                            )
                            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
                            for removing_checkpoint in removing_checkpoints:
                                removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint)
                                shutil.rmtree(removing_checkpoint)
                    save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_steps}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")
            accelerator.wait_for_everyone()
        if global_steps >= train_config.max_train_steps:
            break

    # Create the pipeline using using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()
if __name__ == "__main__":
    # Entry point: parse CLI arguments, then run the latent-extraction pass.
    args = parse_args()
    main(args)
================================================
FILE: projects/preprocess/image_nr_latent_c2i.py
================================================
import os
import torch
import argparse
import datetime
import time
import torchvision
import logging
import math
import shutil
import accelerate
import torch
import torch.utils.checkpoint
import diffusers
import numpy as np
import torch.nn.functional as F
import einops
import json
import os.path as osp
import functools
from PIL import Image
from torch.cuda import amp
from torch.utils.data import DataLoader, Dataset
from omegaconf import OmegaConf
from accelerate import Accelerator, skip_first_batches
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from diffusers import AutoencoderKL, AutoencoderDC
from nit.utils.misc_utils import instantiate_from_config
from torchvision import transforms
from torchvision.datasets.folder import DatasetFolder, default_loader
from torchvision.transforms.functional import hflip
from safetensors.torch import save_file
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from nit.utils.model_utils import dc_ae_encode, sd_vae_encode
logger = get_logger(__name__, log_level="INFO")
# For Omegaconf Tuple
def resolve_tuple(*args):
    """OmegaConf resolver body: pack the interpolated arguments into a tuple."""
    return tuple(args)

# Registered at import time so any config loaded below may use ${tuple:...}.
OmegaConf.register_new_resolver("tuple", resolve_tuple)
def parse_args():
    """Parse this script's command-line interface.

    Returns a namespace with ``config`` (YAML config path), ``project_dir``
    (output root) and ``seed`` (optional reproducibility seed).
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # ----General Training Arguments----
    arg_specs = [
        ("--config", dict(type=str, default="", help="The config file for training.")),
        ("--project_dir", dict(
            type=str,
            default="t2i_linear_attention",
            help="The output directory where the model predictions and checkpoints will be written.",
        )),
        ("--seed", dict(type=int, default=None, help="A seed for reproducible training.")),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def native_resolution_resize(pil_image, min_image_size, max_image_size):
    """Resize an image for native-resolution latent extraction.

    Both output sides are snapped down to multiples of ``min_image_size``:

    * If the image area is below ``max_image_size**2``, each side is floored
      to its nearest multiple independently (at least one multiple each).
    * Otherwise the image is first rescaled, preserving aspect ratio, so its
      area is about ``max_image_size**2``, then floored to multiples.

    Fixes: the docstring previously mis-described this as ADM center
    cropping (no crop happens here), and the large-image branch could
    produce a zero-sized side for extreme aspect ratios; both sides are now
    clamped to at least one multiple of ``min_image_size``, matching the
    small-image branch.
    """
    w, h = pil_image.size
    if w * h < max_image_size**2:
        # Small image: snap each side down to a multiple of min_image_size.
        new_w = max(1, int(w / min_image_size)) * min_image_size
        new_h = max(1, int(h / min_image_size)) * min_image_size
    else:
        # Large image: rescale so new_w * new_h ~= max_image_size**2 while
        # keeping the aspect ratio (new_w/new_h == w/h), then snap down.
        new_w = np.sqrt(w / h) * max_image_size
        new_h = new_w * h / w
        new_w = max(1, int(new_w / min_image_size)) * min_image_size
        new_h = max(1, int(new_h / min_image_size)) * min_image_size
    pil_image = pil_image.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)
    return pil_image
# File extensions treated as images when scanning the dataset directory tree.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way by default: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
    the same methods can be overridden to customize the dataset.

    Unlike torchvision's stock ``ImageFolder``, ``__getitem__`` additionally
    returns the sample's file path (used downstream to name the saved latent).

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid file (used to check of corrupt files)

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        super().__init__(
            root,
            loader,
            # DatasetFolder treats the extension filter and is_valid_file as
            # mutually exclusive: pass extensions only when no predicate given.
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        self.imgs = self.samples

    def __getitem__(self, index: int) -> Tuple[Any, Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target, path) where target is class_index of the
            target class and path is the image's file path.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path
class ImagenetDataDictWrapper(Dataset):
    """Adapts a (sample, label, path)-tuple dataset to dict items with keys
    ``jpg``, ``cls`` and ``path``."""

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, i):
        sample, label, file_path = self.dataset[i]
        return {"jpg": sample, "cls": label, "path": file_path}

    def __len__(self):
        return len(self.dataset)
# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
def get_train_sampler(global_batch_size, max_steps, resume_step):
sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long)
return sample_indices[resume_step * global_batch_size : ].tolist()
class ImagenetLoader():
    """Builds dataloaders that feed native-resolution images (sides snapped to
    multiples of ``min_image_size``), normalized to [-1, 1].

    Fix: ``val_dataloader`` read ``self.shuffle``, which was never assigned
    anywhere and raised AttributeError; it is now initialized to False
    (deterministic order) in ``__init__``.

    NOTE(review): native-resolution images vary in size, and the default
    collate cannot stack mismatched shapes — confirm the config uses
    batch_size=1 (or a custom collate) for this loader.
    """

    def __init__(self, data_config):
        super().__init__()
        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        # Used by val_dataloader(); False keeps validation order deterministic.
        self.shuffle = False
        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: native_resolution_resize(
                pil_image, data_config.dataset.min_image_size, data_config.dataset.max_image_size
            )),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, transform=transform))
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        """Number of samples in the training dataset."""
        return len(self.train_dataset)

    def train_dataloader(self, global_batch_size, max_steps, resume_step):
        """Dataloader driven by a precomputed flat index list so the pass can
        resume at ``resume_step``. drop_last=False: every image must be
        encoded, including the final partial batch."""
        sampler = get_train_sampler(
            global_batch_size, max_steps, resume_step
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=False
        )

    def test_dataloader(self):
        # No test split is configured for this pipeline.
        return None

    def val_dataloader(self):
        # Validation reuses the training dataset (no held-out split here).
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True
        )
def main(args):
project_dir = args.project_dir
config = OmegaConf.load(args.config)
model_config = config.model
data_config = config.data
train_config = config.training
config_dir = osp.join(project_dir, 'configs')
checkpoint_dir = osp.join(project_dir, 'checkpoints')
logging_dir = osp.join(project_dir, 'logs')
sample_dir = osp.join(project_dir, 'samples')
accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=train_config.gradient_accumulation_steps,
mixed_precision=train_config.mixed_precision,
log_with=train_config.tracker,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
)
# Handle the repository creation
if accelerator.is_main_process:
os.makedirs(project_dir, exist_ok=True)
os.makedirs(config_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(logging_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)
OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml"))
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
diffusers.utils.logging.set_verbosity_info()
else:
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
if train_config.allow_tf32: # for A100
torch.backends.cuda.matmul.allow_tf32 = True
# Setup models
weight_dtype = torch.float32
if 'sd-vae' in model_config.vae:
sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)
sd_vae.eval()
sd_vae.requires_grad_(False)
encode_func = functools.partial(sd_vae_encode, sd_vae)
elif 'dc-ae' in model_config.vae:
dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)
dc_ae.eval()
dc_ae.requires_grad_(False)
encode_func = functools.partial(dc_ae_encode, dc_ae)
# Setup Dataloader
total_batch_size = (
data_config.dataloader.batch_size *
accelerator.num_processes *
train_config.gradient_accumulation_steps
)
global_steps = 0
if train_config.resume_from_checkpoint:
# normal read with safety check
if train_config.resume_from_checkpoint != "latest":
resume_from_path = os.path.basename(train_config.resume_from_checkpoint)
else: # Get the most recent checkpoint
dirs = os.listdir(checkpoint_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None
if resume_from_path is None:
logger.info(
f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run."
)
train_config.resume_from_checkpoint = None
else:
global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps
logger.info(f"Resuming from steps: {global_steps}")
get_train_dataloader = ImagenetLoader(data_config)
train_len = get_train_dataloader.train_len()
train_config.max_train_steps = math.ceil(train_len / total_batch_size)
train_dataloader = get_train_dataloader.train_dataloader(
global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps,
)
# Prepare Accelerate
train_dataloader= accelerator.prepare(train_dataloader)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
logger.info("***** Running training *****")
logger.info(f" Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}")
logger.info(f" Dataset Length = {get_train_dataloader.train_len()}")
logger.info(f" Instantaneous batch size per device = {data_config.dataloader.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {train_config.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {train_config.max_train_steps}")
# Potentially load in the weights and states from a previous save
if train_config.resume_from_checkpoint and resume_from_path != None:
accelerator.print(f"Resuming from checkpoint {resume_from_path}")
accelerator.load_state(resume_from_path)
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(0, train_config.max_train_steps),
disable = not accelerator.is_local_main_process
)
progress_bar.set_description("Optim Steps")
progress_bar.update(global_steps)
# prepare patch size and max sequence length
# make directory
os.makedirs(data_config.dataset.target_dir, exist_ok=True)
def save_data(z, y, p):
    """Persist one encoded sample as a .safetensors file.

    `p` looks like 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG';
    the latent/label pair lands in '<target_dir>/n01440764/<stem>.safetensors'.
    """
    parts = p.split('/')
    # Class-named subfolder mirrors the ImageNet layout (second-to-last path part).
    class_dir = os.path.join(data_config.dataset.target_dir, parts[-2])
    stem = parts[-1].split('.')[0]
    os.makedirs(class_dir, exist_ok=True)
    payload = dict(latent=z.contiguous(), label=y)
    save_file(payload, os.path.join(class_dir, f'{stem}.safetensors'))
for step, batch in enumerate(train_dataloader, start=global_steps):
for batch_key in batch.keys():
if not isinstance(batch[batch_key], (list, str)):
batch[batch_key] = batch[batch_key].to(dtype=weight_dtype)
x = batch['jpg']
y = batch['cls']
p = batch['path']
if 'sd-vae' in model_config.vae:
z_ori = encode_func(x)
z_hflip = encode_func(hflip(x))
z = torch.stack([z_ori, z_hflip], dim=1)
elif 'dc-ae' in model_config.vae:
z_ori = encode_func(x)
z_hflip = encode_func(hflip(x))
z = torch.stack([z_ori, z_hflip], dim=1)
for i in range(len(p)):
save_data(z[i], y[i], p[i])
# Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation
if accelerator.sync_gradients:
progress_bar.update(1)
global_steps += 1
if global_steps % train_config.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if train_config.checkpoints_total_limit is not None:
checkpoints = os.listdir(checkpoint_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= train_config.checkpoints_total_limit:
num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_steps}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
accelerator.wait_for_everyone()
if global_steps >= train_config.max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run preprocessing.
    main(parse_args())
================================================
FILE: projects/sample/sample_c2i_ddp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Samples a large number of images from a pre-trained SiT model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations
For a simple single-GPU/CPU sampling script, see sample.py.
"""
import torch
import torch.distributed as dist
from diffusers.models import AutoencoderKL, AutoencoderDC
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import functools
import argparse
from omegaconf import OmegaConf
from einops import rearrange
from nit.schedulers.flow_matching.samplers_c2i import euler_sampler, euler_maruyama_sampler
from nit.utils import init_from_ckpt
from nit.utils.misc_utils import instantiate_from_config
from nit.utils.model_utils import sd_vae_decode, dc_ae_decode
def create_npz_from_sample_folder(sample_dir, num=50_000):
    """
    Builds a single .npz file from a folder of .png samples.

    Collects `num` PNGs named 000000.png, 000001.png, ... into one uint8
    array of shape (num, H, W, 3) and writes it to `<sample_dir>.npz` under
    key `arr_0` — the layout the ADM FID-evaluation suite expects.
    """
    imgs = [
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.stack(imgs)
    # Sanity check: every sample must be RGB (3 channels).
    assert stacked.shape == (num, stacked.shape[1], stacked.shape[2], 3)
    npz_path = f"{sample_dir}.npz"
    np.savez(npz_path, arr_0=stacked)
    print(f"Saved .npz file to {npz_path} [shape={stacked.shape}].")
    return npz_path
def main(args):
    """
    Run DDP sampling.

    Each rank draws class-conditional latent noise, denoises it with the
    flow-matching sampler (ODE or SDE), decodes through the configured VAE,
    and saves its share of the PNGs for downstream FID evaluation.
    """
    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences
    assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
    torch.set_grad_enabled(False)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    # Distinct seed per rank so ranks sample different noise/classes.
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # setup dtype
    dtype = torch.bfloat16

    # Load the VAE; spatial_downsample maps pixel resolution to latent resolution.
    config = OmegaConf.load(args.config)
    model_config = config.model
    if 'dc-ae' in model_config.vae_dir:
        dc_ae = AutoencoderDC.from_pretrained(model_config.vae_dir).to(device)
        if args.slice_vae:
            # Decode batch elements one at a time to bound peak VRAM.
            # (The original called enable_slicing() twice; once is enough.)
            dc_ae.enable_slicing()
        spatial_downsample = 32
        decode_func = functools.partial(dc_ae_decode, dc_ae)
    elif 'sd-vae' in model_config.vae_dir:
        sd_vae = AutoencoderKL.from_pretrained(model_config.vae_dir).to(device)
        if args.slice_vae:
            sd_vae.enable_slicing()
        spatial_downsample = 8
        decode_func = functools.partial(sd_vae_decode, sd_vae)
    else:
        # Was a bare `raise` (RuntimeError: no active exception); raise a clear error instead.
        raise ValueError(f"Cannot infer VAE type from vae_dir: {model_config.vae_dir}")
    assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"

    # image resolution, expressed in latent patches
    patch_size = int(model_config.network.params.patch_size)
    latent_h = int(args.height / spatial_downsample / patch_size)
    latent_w = int(args.width / spatial_downsample / patch_size)
    # Optional positional-embedding interpolation overrides for resolution extrapolation.
    if args.interpolation != 'no':
        model_config.network.params['custom_freqs'] = args.interpolation
        model_config.network.params['max_pe_len_h'] = latent_h
        model_config.network.params['max_pe_len_w'] = latent_w
        model_config.network.params['decouple'] = args.decouple
        model_config.network.params['ori_max_pe_len'] = int(args.ori_max_pe_len)
    model = instantiate_from_config(model_config.network).to(device=device, dtype=dtype)
    init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True)
    model.eval()  # important!
    # Optional auto-guidance model (used by the samplers when provided).
    if args.ag_config is not None and args.ag_ckpt is not None:
        ag_config = OmegaConf.load(args.ag_config)
        ag_model_config = ag_config.model
        ag_model = instantiate_from_config(ag_model_config.network).to(device=device, dtype=dtype)
        init_from_ckpt(ag_model, checkpoint_dir=args.ag_ckpt, ignore_keys=None, verbose=True)
        ag_model.eval()  # important!
    else:
        ag_model = None

    # Create folder to save samples:
    train_iter = args.ckpt.split('/')[-2].split('-')[-1]
    folder_name = f"{train_iter}-{args.height}x{args.width}-{args.mode}-{args.num_steps}-" \
        f"cfg-{args.cfg_scale}-low-{args.guidance_low}-high-{args.guidance_high}"
    if ag_model is not None:
        sample_folder_dir = f"{args.sample_dir}/ag-{folder_name}"
    else:
        sample_folder_dir = f"{args.sample_dir}/{folder_name}"
    if args.interpolation != 'no':
        sample_folder_dir += f'-{args.interpolation}'
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .png samples at {sample_folder_dir}")
    dist.barrier()

    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
    n = args.per_proc_batch_size
    global_batch_size = n * dist.get_world_size()
    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
        print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size"
    samples_needed_this_gpu = int(total_samples // dist.get_world_size())
    assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size"
    iterations = int(samples_needed_this_gpu // n)
    pbar = range(iterations)
    pbar = tqdm(pbar) if rank == 0 else pbar
    total = 0
    for _ in pbar:
        # Sample inputs: one noise token per latent patch, flattened across the batch.
        z = torch.randn(
            (n*latent_h*latent_w, model.in_channels, patch_size, patch_size),
            device=device, dtype=dtype
        )
        y = torch.randint(0, args.num_classes, (n,), device=device)
        hw_list = torch.tensor([[latent_h, latent_w] for _ in range(n)], device=device, dtype=torch.int)
        seqlens = hw_list[:, 0] * hw_list[:, 1]
        # NOTE(review): cu_seqlens is computed but not used below — confirm whether
        # it was meant to be passed to the sampler.
        cu_seqlens = torch.cat([
            torch.tensor([0], device=hw_list.device, dtype=torch.int32),
            torch.cumsum(seqlens, dim=0, dtype=torch.int32)
        ])
        # Resume support: skip this batch entirely if every target PNG already exists.
        can_pass = True
        for j in range(n):
            index = j * dist.get_world_size() + rank + total
            if not os.path.exists(f"{sample_folder_dir}/{index:06d}.png"):
                can_pass = False
        if can_pass:
            total += global_batch_size
            print('total: ', total)
            continue
        # Sample images:
        sampling_kwargs = dict(
            model=model,
            ag_model=ag_model,
            latents=z,
            y=y,
            hw_list=hw_list,
            num_steps=args.num_steps,
            heun=args.heun,
            cfg_scale=args.cfg_scale,
            guidance_low=args.guidance_low,
            guidance_high=args.guidance_high,
            path_type=args.path_type,
        )
        with torch.no_grad():
            if args.mode == "sde":
                samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)
            elif args.mode == "ode":
                samples = euler_sampler(**sampling_kwargs).to(torch.float32)
            else:
                raise NotImplementedError
        # Un-patchify: (b*h*w, c, p, p) -> (b, c, h*p, w*p), then decode to pixels in [-1, 1].
        samples = rearrange(samples, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=latent_h, w=latent_w)
        samples = decode_func(samples)
        samples = (samples + 1) / 2.
        samples = torch.clamp(
            255. * samples, 0, 255
        ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
        # Save samples to disk as individual .png files.
        # Renamed the loop variable: it previously shadowed the outer iteration variable `i`.
        for sample_idx, sample in enumerate(samples):
            index = sample_idx * dist.get_world_size() + rank + total
            Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
        total += global_batch_size

    # Make sure all processes have finished saving their samples before attempting to convert to .npz
    dist.barrier()
    if rank == 0:
        # create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
        print("Done.")
    dist.barrier()
    dist.destroy_process_group()
if __name__ == "__main__":
    # CLI for DDP sampling. Defaults reproduce the paper's 256x256 ODE setup.
    parser = argparse.ArgumentParser()
    # seed
    parser.add_argument("--global-seed", type=int, default=0)
    # precision
    parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True,
                        help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.")
    # logging/saving:
    parser.add_argument("--config", type=str, default=None, help="Optional config to a SiT checkpoint.")
    parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.")
    parser.add_argument("--sample-dir", type=str, default="workdir/c2i/samples")
    # Optional auto-guidance model (both config and ckpt must be given to enable it).
    parser.add_argument("--ag-config", type=str, default=None)
    parser.add_argument("--ag-ckpt", type=str, default=None)
    # model
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    # Decode VAE batches one element at a time to bound memory.
    parser.add_argument("--slice_vae", action=argparse.BooleanOptionalAction, default=False) # only for ode
    # number of samples
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000)
    # sampling related hyperparameters
    parser.add_argument("--mode", type=str, default="ode")
    parser.add_argument("--cfg-scale", type=float, default=1.5)
    parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
    parser.add_argument("--num-steps", type=int, default=50)
    parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode
    # CFG is only applied within the [guidance-low, guidance-high] time interval.
    parser.add_argument("--guidance-low", type=float, default=0.)
    parser.add_argument("--guidance-high", type=float, default=1.)
    # Positional-embedding interpolation scheme for sampling beyond the training resolution.
    parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'ntk-aware', 'ntk-by-parts', 'yarn', 'ntk-aware-pro1', 'ntk-aware-pro2', 'scale1', 'scale2'], default='no') # interpolation
    parser.add_argument("--ori-max-pe-len", default=None, type=int)
    parser.add_argument("--decouple", default=False, action="store_true") # interpolation
    # will be deprecated
    parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode
    args = parser.parse_args()
    main(args)
================================================
FILE: projects/train/packed_trainer_c2i.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import copy
import functools
import gc
import itertools
import json
import logging
import math
import os
import time
import random
import shutil
import importlib
import csv
import numpy as np
import os.path as osp
from pathlib import Path
from typing import List, Union
from packaging import version
from tqdm.auto import tqdm
from copy import deepcopy
from omegaconf import OmegaConf
from einops import rearrange, repeat
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
from torch.utils.data import default_collate, Dataset
from torchvision import transforms
from torchvision.transforms import Normalize
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
import transformers
import diffusers
from diffusers.optimization import get_scheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers import AutoencoderKL
from timeit import default_timer as timer
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from nit.schedulers.flow_matching.loss import FlowMatchingLoss
from nit.data.packed_c2i_data import C2ILoader
from nit.utils.misc_utils import (
get_obj_from_str, get_dtype, instantiate_from_config
)
from nit.utils.train_utils import (
update_ema, log_validation,
)
from nit.utils.gpu_memory_monitor import build_gpu_memory_monitor
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")
logger = get_logger(__name__)
def parse_args():
    """Build and parse the command-line arguments for this training script."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # ----General Training Arguments----
    parser.add_argument(
        "--config", type=str, default="",
        help="The config file for training.",
    )
    parser.add_argument(
        "--project_dir", type=str, default="t2i_linear_attention",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="A seed for reproducible training."
    )
    return parser.parse_args()
def main(args):
    """Entry point: packed class-to-image (c2i) training driven by a YAML config.

    Sets up an (optionally FSDP-sharded) Accelerate run, builds the network,
    optimizer, LR scheduler and packed dataloader, optionally resumes from a
    checkpoint, then runs the flow-matching training loop with EMA updates,
    periodic checkpointing/validation, and timing + GPU-memory logging.
    """
    project_dir = args.project_dir
    config = OmegaConf.load(args.config)
    model_config = config.model
    data_config = config.data
    train_config = config.training
    # Per-project output layout: configs/, checkpoints/, logs/, samples/.
    config_dir = osp.join(project_dir, 'configs')
    checkpoint_dir = osp.join(project_dir, 'checkpoints')
    logging_dir = osp.join(project_dir, 'logs')
    sample_dir = osp.join(project_dir, 'samples')
    # Optional FSDP: translate string names from the YAML into torch FSDP enums.
    if getattr(train_config, 'fsdp_config', None) != None:
        import functools
        from torch.distributed.fsdp.fully_sharded_data_parallel import (
            BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision,
            StateDictType, FullStateDictConfig, FullOptimStateDictConfig,
        )
        from accelerate.utils import FullyShardedDataParallelPlugin
        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
        fsdp_cfg = train_config.fsdp_config
        # Param/reduce dtype for FSDP mixed precision follows the Accelerate setting.
        if train_config.mixed_precision == "fp16":
            dtype = torch.float16
        elif train_config.mixed_precision == "bf16":
            dtype = torch.bfloat16
        else:
            dtype = torch.float32
        fsdp_plugin = FullyShardedDataParallelPlugin(
            sharding_strategy = {
                'FULL_SHARD': ShardingStrategy.FULL_SHARD,
                'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
                'NO_SHARD': ShardingStrategy.NO_SHARD,
                'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,
                'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,
            }[fsdp_cfg.sharding_strategy],
            backward_prefetch = {
                'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
                'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
            }[fsdp_cfg.backward_prefetch],
            mixed_precision_policy = MixedPrecision(
                param_dtype=dtype,
                reduce_dtype=dtype,
            ),
            # Wrap any submodule above the configured parameter count.
            auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params
            ),
            cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload),
            state_dict_type = {
                'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT,
                'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT,
                'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT
            }[fsdp_cfg.state_dict_type],
            # Gather full state dicts on rank 0 / CPU when saving.
            state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            limit_all_gathers = fsdp_cfg.limit_all_gathers,
            use_orig_params = fsdp_cfg.use_orig_params,
            sync_module_states = fsdp_cfg.sync_module_states,
            forward_prefetch = fsdp_cfg.forward_prefetch,
            activation_checkpointing = fsdp_cfg.activation_checkpointing,
        )
    else:
        fsdp_plugin = None
    accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=train_config.gradient_accumulation_steps,
        mixed_precision=train_config.mixed_precision,
        log_with=train_config.tracker,
        project_config=accelerator_project_config,
        split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
        fsdp_plugin=fsdp_plugin,
    )
    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs(project_dir, exist_ok=True)
        os.makedirs(config_dir, exist_ok=True)
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(logging_dir, exist_ok=True)
        os.makedirs(sample_dir, exist_ok=True)
        # Snapshot the resolved config for reproducibility.
        OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml"))
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
    if train_config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
    # Effective (global) batch size across processes and accumulation steps.
    total_batch_size = (
        data_config.dataloader.batch_size *
        accelerator.num_processes *
        train_config.gradient_accumulation_steps
    )
    # Linear LR scaling rule relative to the configured base batch size.
    if train_config.scale_lr:
        learning_rate = (
            train_config.learning_rate *
            total_batch_size / train_config.learning_rate_base_batch_size
        )
    else:
        learning_rate = train_config.learning_rate
    # prepare model, dataloader, optimizer and scheduler
    model = instantiate_from_config(model_config.network).to(device=accelerator.device)
    model.train()
    if model_config.use_ema:
        # EMA shadow copy; updated manually via update_ema(), never by the optimizer.
        ema_model = deepcopy(model)
        ema_model.train()
        ema_model.requires_grad_(False)
    # Handle mixed precision and device placement
    # Check that all trainable models are in full precision
    low_precision_error_string = (
        " Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training, copy of the weights should still be float32."
    )
    if accelerator.unwrap_model(model).dtype != torch.float32:
        raise ValueError(
            f"Controlnet loaded as datatype {accelerator.unwrap_model(model).dtype}. {low_precision_error_string}"
        )
    # Parameter accounting (debug print on the main process only).
    if accelerator.is_main_process:
        total_params = 0
        trainable_params = 0
        projector_params = 0
        for name, param in model.named_parameters():
            print(name, param.requires_grad)
            total_params += param.numel() # Total number of elements in the parameter
            if param.requires_grad: # Check if the parameter is trainable
                trainable_params += param.numel()
            if 'projector' in name:
                projector_params += param.numel()
        print(trainable_params, total_params, total_params-projector_params, trainable_params/total_params)
    # Optimizer creation
    target_optimizer = train_config.optimizer.get('target', 'torch.optim.AdamW')
    optimizer = get_obj_from_str(target_optimizer)(
        model.parameters(), lr=learning_rate,
        **train_config.optimizer.get("params", dict())
    )
    # Dataset creation and data processing
    # Here, we compute not just the text embeddings but also the additional embeddings
    global_steps = 0
    if train_config.resume_from_checkpoint:
        # normal read with safety check
        if train_config.resume_from_checkpoint != "latest":
            # NOTE(review): basename() discards the directory, so a full path given in
            # the config becomes a bare name that accelerator.load_state() resolves
            # relative to the CWD — confirm this is intended.
            resume_from_path = os.path.basename(train_config.resume_from_checkpoint)
        else: # Get the most recent checkpoint
            dirs = os.listdir(checkpoint_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None
        if resume_from_path is None:
            logger.info(
                f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            train_config.resume_from_checkpoint = None
        else:
            # NOTE(review): split("-")[1] on the joined path breaks if checkpoint_dir
            # contains a '-'; osp.basename(resume_from_path).split("-")[1] would be safer.
            global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps
            logger.info(f"Resuming from steps: {global_steps}")
    get_train_dataloader = C2ILoader(data_config)
    train_dataloader = get_train_dataloader.train_dataloader(
        rank=accelerator.process_index, world_size=accelerator.num_processes,
        global_batch_size=total_batch_size, max_steps=train_config.max_train_steps,
        resume_steps=global_steps, seed=args.seed
    )
    # LR Scheduler creation
    # Scheduler and math around the number of training steps.
    lr_scheduler = get_scheduler(
        train_config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=train_config.lr_warmup_steps,
        num_training_steps=train_config.max_train_steps,
    )
    # Prepare for training
    # Prepare everything with our `accelerator`.
    if model_config.use_ema:
        ema_model, model, optimizer, lr_scheduler = accelerator.prepare(
            ema_model, model, optimizer, lr_scheduler
        )
    else:
        model, optimizer, lr_scheduler = accelerator.prepare(
            model, optimizer, lr_scheduler
        )
    # transport
    loss_fn = FlowMatchingLoss(**OmegaConf.to_container(model_config.transport))
    # Frozen RADIO encoder used as the feature target for the projector loss.
    if model_config.enc_type == 'radio':
        from nit.models.nvidia_radio.hubconf import radio_model
        encoder = radio_model(version=model_config.enc_dir, progress=True, support_packing=True)
        encoder.to(device=accelerator.device).eval()
        encoder.requires_grad_(False)
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process and getattr(train_config, 'tracker', 'wandb') != None:
        tracker_project_name = project_dir.split('/')[-1]
        # accelerator.init_trackers("mcga", config=config, init_kwargs=train_config.tracker_kwargs)
        accelerator.init_trackers(tracker_project_name, config=config, init_kwargs=train_config.tracker_kwargs)
    # initialize GPU memory monitor before applying parallelisms to the model
    gpu_memory_monitor = build_gpu_memory_monitor(logger)
    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    # 15. Train!
    logger.info("***** Running training *****")
    logger.info(f"  Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}")
    logger.info(f"  Dataset Length = {get_train_dataloader.train_len()}")
    logger.info(f"  Instantaneous batch size per device = {data_config.dataloader.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {train_config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {train_config.max_train_steps}")
    logger.info(
        "  GPU memory usage for model: "
        f"{gpu_mem_stats.max_reserved_gib:.2f}GiB"
        f"({gpu_mem_stats.max_reserved_pct:.2f}%)"
    )
    gpu_memory_monitor.reset_peak_stats()
    data_loading_times = []
    feat_enc_times = []
    # Potentially load in the weights and states from a previous save
    if train_config.resume_from_checkpoint and resume_from_path != None:
        accelerator.print(f"Resuming from checkpoint {resume_from_path}")
        accelerator.load_state(resume_from_path)
    progress_bar = tqdm(
        range(0, train_config.max_train_steps),
        initial=global_steps,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_main_process,
    )
    for batch in train_dataloader:
        time_last_log = timer()
        data_load_start = timer()
        # load dataset from batch
        # batch['image']: list of per-sample image tensors; 'latent'/'label'/'hw_list'
        # carry a leading packing dimension that squeeze(0) removes.
        batch_image = [image.to(accelerator.device) for image in batch['image']]
        batch_label = batch['label'].squeeze(0).to(accelerator.device, torch.int)
        packed_latent = batch['latent'].squeeze(0).to(accelerator.device)
        noises = torch.randn_like(packed_latent)
        hw_list = batch['hw_list'].squeeze(0).to(torch.int)
        batch_size = hw_list.shape[0]
        # Classifier-free guidance training: randomly replace labels with the
        # null class (index == num_classes).
        dropout_prob = model_config.network.params.class_dropout_prob
        num_classes = model_config.network.params.num_classes
        if dropout_prob > 0:
            drop_ids = torch.rand(batch_label.shape[0], device=accelerator.device) < dropout_prob
            batch_label = torch.where(drop_ids, num_classes, batch_label)
        data_loading_times.append(timer() - data_load_start)
        feat_enc_start = timer()
        # Encode target features with the frozen encoder (no gradients needed).
        zs = []
        if model_config.enc_type == 'radio':
            with torch.no_grad(), accelerator.autocast():
                # RADIO expects inputs in [0, 1]; images arrive in [-1, 1].
                raw_images = [(image.unsqueeze(0)+1.0)/2.0 for image in batch_image]
                _, z = encoder.forward_pack(raw_images)
            zs.append(z)
        feat_enc_times.append(timer() - feat_enc_start)
        with accelerator.accumulate(model):
            # forward and calculate loss
            model_kwargs = dict(y=batch_label, hw_list=hw_list)
            fm_loss, proj_loss = loss_fn(model, batch_size, packed_latent, noises, model_kwargs, use_dir_loss=True, zs=zs)
            loss = fm_loss + model_config.proj_coeff * proj_loss
            accelerator.backward(loss)
            if accelerator.sync_gradients and train_config.max_grad_norm > 0:
                all_norm = accelerator.clip_grad_norm_(model.parameters(), train_config.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            # 20.4.15. Make EMA update to target student model parameters
            if model_config.use_ema:
                update_ema(ema_model, model, model_config.ema_decay)
            global_steps += 1
            # Timing stats: samples/sec plus data-loading and feature-encoding shares.
            time_delta = timer() - time_last_log
            sps = batch_size / time_delta
            time_data_loading = np.mean(data_loading_times)
            time_feat_enc = np.mean(feat_enc_times)
            time_data_loading_pct = 100 * time_data_loading / time_delta
            time_feat_enc_pct = 100 * time_feat_enc / time_delta
            gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
            if global_steps % train_config.checkpointing_steps == 0:
                # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                if accelerator.is_main_process and train_config.checkpoints_total_limit is not None:
                    checkpoints = os.listdir(checkpoint_dir)
                    checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
                    # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                    if len(checkpoints) >= train_config.checkpoints_total_limit:
                        num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1
                        removing_checkpoints = checkpoints[0:num_to_remove]
                        logger.info(
                            f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                        )
                        logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
                        for removing_checkpoint in removing_checkpoints:
                            removing_checkpoint = osp.join(checkpoint_dir, removing_checkpoint)
                            try:
                                shutil.rmtree(removing_checkpoint)
                            # NOTE(review): bare except silently swallows all errors
                            # (including KeyboardInterrupt); narrowing to OSError and
                            # logging would be safer.
                            except:
                                pass
                save_path = osp.join(checkpoint_dir, f"checkpoint-{global_steps}")
                if accelerator.is_main_process:
                    os.makedirs(save_path, exist_ok=True)
                accelerator.save_state(save_path)
                logger.info(f"Saved state to {save_path}")
            # Extra, never-pruned snapshots at milestone steps from the config.
            if global_steps in train_config.checkpoint_list:
                save_path = os.path.join(checkpoint_dir, f"save-checkpoint-{global_steps}")
                if accelerator.is_main_process:
                    os.makedirs(save_path, exist_ok=True)
                accelerator.save_state(save_path)
                logger.info(f"Saved state to {save_path}")
                time.sleep(10)
                torch.cuda.empty_cache()
            if global_steps % train_config.validation_steps == 0:
                log_validation(model)
            logs = {
                # loss and lr
                "loss_denoising": fm_loss.detach().item(),
                "loss_projector": proj_loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                # time and status
                "sps": sps,
                "data_loading(s)": time_data_loading,
                "data_loading(%)": time_data_loading_pct,
                "time_feat_enc(s)": time_feat_enc,
                "time_feat_enc(%)": time_feat_enc_pct,
                "memory_max_active(GiB)": gpu_mem_stats.max_active_gib,
                "memory_max_active(%)": gpu_mem_stats.max_active_pct,
                "memory_max_reserved(GiB)": gpu_mem_stats.max_reserved_gib,
                "memory_max_reserved(%)": gpu_mem_stats.max_reserved_pct,
                "memory_num_alloc_retries": gpu_mem_stats.num_alloc_retries,
                "memory_num_ooms": gpu_mem_stats.num_ooms
            }
            if accelerator.sync_gradients and train_config.max_grad_norm > 0:
                logs.update({'grad_norm': all_norm.item()})
            progress_bar.set_postfix(**logs)
            progress_bar.update(1)
            accelerator.log(logs, step=global_steps)
        if global_steps >= train_config.max_train_steps:
            break
    # Create the pipeline using using the trained modules and save it.
    accelerator.wait_for_everyone()
    accelerator.end_training()
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and launch training.
    main(parse_args())
================================================
FILE: requirements.txt
================================================
diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested
transformers>=4.44.2 # developed and tested against transformers 4.44.2
accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested
sentencepiece>=0.2.0 # T5 used
numpy==1.26.0
streamlit>=1.38.0 # For streamlit web demo
imageio==2.34.2 # For diffusers inference export video
imageio-ffmpeg==0.5.1 # For diffusers inference export video
moviepy==1.0.3 # For export video
pillow==9.5.0
timm
safetensors
einops
triton
torchdiffeq
================================================
FILE: scripts/preprocess/preorocess_in1k_256x256.sh
================================================
# Single-node, 8-GPU launch for ImageNet-1k 256x256 latent preprocessing.
NNODES=1
GPUS_PER_NODE=8
MASTER_ADDR="localhost"
# Random rendezvous port in [30000, 50999] to avoid collisions between runs.
export MASTER_PORT=$((30000 + $RANDOM % 21000))
# Preprocessing script and its arguments.
CMD=" \
    projects/preprocess/image_latent_c2i.py \
    --config configs/preprocess/imagenet1k_256x256.yaml \
    --project_dir workdir/preprocess/imagenet1k_256x256 \
    --seed 0 \
    "
# torchrun launcher with c10d rendezvous.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "
bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/preprocess/preorocess_in1k_512x512.sh
================================================
# Single-node, 8-GPU launch for ImageNet-1k 512x512 latent preprocessing.
NNODES=1
GPUS_PER_NODE=8
MASTER_ADDR="localhost"
# Random rendezvous port in [30000, 50999] to avoid collisions between runs.
export MASTER_PORT=$((30000 + $RANDOM % 21000))
# Preprocessing script and its arguments.
CMD=" \
    projects/preprocess/image_latent_c2i.py \
    --config configs/preprocess/imagenet1k_512x512.yaml \
    --project_dir workdir/preprocess/imagenet1k_512x512 \
    --seed 0 \
    "
# torchrun launcher with c10d rendezvous.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "
bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/preprocess/preorocess_in1k_native_resolution.sh
================================================
# Single-node launcher: precompute native-resolution ImageNet-1K latents.
NNODES=1
GPUS_PER_NODE=8
MASTER_ADDR="localhost"
# Pick a random rendezvous port in [30000, 50999] to avoid collisions.
export MASTER_PORT=$((30000 + $RANDOM % 21000))

JOB="imagenet1k_native_resolution"

# Native-resolution preprocessing entry point and its arguments.
CMD=" \
    projects/preprocess/image_nr_latent_c2i.py \
    --config configs/preprocess/${JOB}.yaml \
    --project_dir workdir/preprocess/${JOB} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/sample/sample_256x256.sh
================================================
# Sample 256x256 images with the NiT-XL model on 8 GPUs (SDE sampler, 250 steps).
# Fix: the last option previously ended with a trailing "\", which would splice
# any line appended after it into the command; the continuation is removed.
torchrun \
    --nnodes 1 \
    --nproc_per_node 8 \
    projects/sample/sample_c2i_ddp.py \
    --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \
    --ckpt checkpoints/nit_xl_model_1000K.safetensors \
    --sample-dir ./samples \
    --height 256 \
    --width 256 \
    --per-proc-batch-size 32 \
    --mode sde \
    --num-steps 250 \
    --cfg-scale 2.25 \
    --guidance-low 0.0 \
    --guidance-high 0.7 \
    --slice_vae
================================================
FILE: scripts/sample/sample_512x512.sh
================================================
# Sample 512x512 images with the NiT-XL model on 8 GPUs (SDE sampler, 250 steps).
# Fix: the last option previously ended with a trailing "\", which would splice
# any line appended after it into the command; the continuation is removed.
torchrun \
    --nnodes 1 \
    --nproc_per_node 8 \
    projects/sample/sample_c2i_ddp.py \
    --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \
    --ckpt checkpoints/nit_xl_model_1000K.safetensors \
    --sample-dir ./samples \
    --height 512 \
    --width 512 \
    --per-proc-batch-size 32 \
    --mode sde \
    --num-steps 250 \
    --cfg-scale 2.05 \
    --guidance-low 0.0 \
    --guidance-high 0.7 \
    --slice_vae
================================================
FILE: scripts/sample/sample_768x768.sh
================================================
# Sample 768x768 images with the NiT-XL model on 8 GPUs (ODE sampler, 50 steps).
# Fix: the last option previously ended with a trailing "\", which would splice
# any line appended after it into the command; the continuation is removed.
torchrun \
    --nnodes 1 \
    --nproc_per_node 8 \
    projects/sample/sample_c2i_ddp.py \
    --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \
    --ckpt checkpoints/nit_xl_model_1000K.safetensors \
    --sample-dir ./samples \
    --height 768 \
    --width 768 \
    --per-proc-batch-size 32 \
    --mode ode \
    --num-steps 50 \
    --cfg-scale 3.0 \
    --guidance-low 0.0 \
    --guidance-high 0.7 \
    --slice_vae
================================================
FILE: scripts/train/train_b_model.sh
================================================
# Single-node launcher: train the NiT-B model with packed sequences.
NNODES=1
GPUS_PER_NODE=2
MASTER_ADDR="localhost"
export MASTER_PORT=60563

RUN="nit_b_pack_merge_radio_65536"
mkdir -p "workdir/c2i/${RUN}"

# Training entry point and its arguments.
CMD=" \
    projects/train/packed_trainer_c2i.py \
    --config configs/c2i/${RUN}.yaml \
    --project_dir workdir/c2i/${RUN} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/train/train_l_model.sh
================================================
# Single-node launcher: train the NiT-L model with packed sequences.
NNODES=1
GPUS_PER_NODE=2
MASTER_ADDR="localhost"
export MASTER_PORT=60563

RUN="nit_l_pack_merge_radio_16384"
mkdir -p "workdir/c2i/${RUN}"

# Training entry point and its arguments.
CMD=" \
    projects/train/packed_trainer_c2i.py \
    --config configs/c2i/${RUN}.yaml \
    --project_dir workdir/c2i/${RUN} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/train/train_s_model.sh
================================================
# Single-node launcher: train the NiT-S model with packed sequences.
NNODES=1
GPUS_PER_NODE=2
MASTER_ADDR="localhost"
export MASTER_PORT=60563

RUN="nit_s_pack_merge_radio_65536"
mkdir -p "workdir/c2i/${RUN}"

# Training entry point and its arguments.
CMD=" \
    projects/train/packed_trainer_c2i.py \
    --config configs/c2i/${RUN}.yaml \
    --project_dir workdir/c2i/${RUN} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/train/train_xl_model.sh
================================================
# Single-node launcher: train the NiT-XL model with packed sequences.
NNODES=1
GPUS_PER_NODE=8
MASTER_ADDR="localhost"
export MASTER_PORT=60563

RUN="nit_xl_pack_merge_radio_16384"
mkdir -p "workdir/c2i/${RUN}"

# Training entry point and its arguments.
CMD=" \
    projects/train/packed_trainer_c2i.py \
    --config configs/c2i/${RUN}.yaml \
    --project_dir workdir/c2i/${RUN} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: scripts/train/train_xxl_model.sh
================================================
# Single-node launcher: train the NiT-XXL model with packed sequences.
NNODES=1
GPUS_PER_NODE=8
MASTER_ADDR="localhost"
export MASTER_PORT=60563

RUN="nit_xxl_pack_merge_radio_8192"
mkdir -p "workdir/c2i/${RUN}"

# Training entry point and its arguments.
CMD=" \
    projects/train/packed_trainer_c2i.py \
    --config configs/c2i/${RUN}.yaml \
    --project_dir workdir/c2i/${RUN} \
    --seed 0 \
    "

# Distributed launcher using the c10d rendezvous backend.
TORCHLAUNCHER="torchrun \
    --nnodes $NNODES \
    --nproc_per_node $GPUS_PER_NODE \
    --rdzv_id $RANDOM \
    --rdzv_backend c10d \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    "

bash -c "$TORCHLAUNCHER $CMD"
================================================
FILE: setup.py
================================================
"""Minimal packaging script for the ``nit`` library."""
from setuptools import find_packages, setup

setup(
    name="nit",
    version="0.0.1",
    description="",
    # Heavy/optional dependencies live in requirements.txt; only the
    # hard runtime requirements are declared here.
    packages=find_packages(),
    install_requires=["torch", "numpy"],
)
================================================
FILE: tools/download_dataset_256x256.sh
================================================
# Download and extract the preprocessed 256x256 ImageNet-1K latent shards.
set -e  # stop on a failed download/unzip instead of deleting a partial archive

target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256"
mkdir -p "$target_dir"
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-256x256"
files=(
    "n01440764_n02097298.zip"
    "n02097474_n02667093.zip"
    "n02669723_n03530642.zip"
    "n03532672_n04239074.zip"
    "n04243546_n15075141.zip"
)
for file in "${files[@]}"; do
    echo "download $file ..."
    # -c resumes a partially-downloaded archive.
    wget -c "$base_url/$file" -O "$target_dir/$file"
    echo "download $file finished"
    echo "start unzip $file ..."
    unzip "$target_dir/$file" -d "$target_dir"
    echo "unzip $file finished"
    # Remove the archive once extracted to save disk space.
    rm "$target_dir/$file"
    echo
done
# Fixed copy-pasted message: this script fetches dataset shards, not sampler-meta.
echo "Successfully downloaded all the dataset files"
================================================
FILE: tools/download_dataset_512x512.sh
================================================
# Download and extract the preprocessed 512x512 ImageNet-1K latent shards.
set -e  # stop on a failed download/unzip instead of deleting a partial archive

target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512"
mkdir -p "$target_dir"
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-512x512"
files=(
    "n01440764_n01697457.zip"
    "n01698640_n01855672.zip"
    "n01860187_n02074367.zip"
    "n02077923_n02097298.zip"
    "n02097474_n02110063.zip"
    "n02110185_n02138441.zip"
    "n02165105_n02415577.zip"
    "n02417914_n02667093.zip"
    "n02669723_n02859443.zip"
    "n02860847_n03041632.zip"
    "n03042490_n03291819.zip"
    "n03297495_n03530642.zip"
    "n03532672_n03743016.zip"
    "n03759954_n03884397.zip"
    "n03887697_n04033901.zip"
    "n04033995_n04239074.zip"
    "n04243546_n04398044.zip"
    "n04399382_n04560804.zip"
    "n04562935_n07745940.zip"
    "n07747607_n15075141.zip"
)
for file in "${files[@]}"; do
    echo "download $file ..."
    # -c resumes a partially-downloaded archive.
    wget -c "$base_url/$file" -O "$target_dir/$file"
    echo "download $file finished"
    echo "start unzip $file ..."
    unzip "$target_dir/$file" -d "$target_dir"
    echo "unzip $file finished"
    # Remove the archive once extracted to save disk space.
    rm "$target_dir/$file"
    echo
done
# Fixed copy-pasted message: this script fetches dataset shards, not sampler-meta.
echo "Successfully downloaded all the dataset files"
================================================
FILE: tools/download_dataset_data_meta.sh
================================================
# Fetch the per-image metadata JSONL files (one line per image, including
# latent sizes) consumed by the packed dataloader and tools/pack_dataset.py.
target_dir="datasets/imagenet1k/data_meta"
mkdir -p $target_dir
# All metadata files live under this HuggingFace dataset repository.
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/data_meta"
files=(
    "dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl"
    "dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl"
    "dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl"
    "dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl"
)
for file in "${files[@]}"; do
    echo "download $file ..."
    # wget -c resumes a partially-downloaded file on re-run.
    wget -c "$base_url/$file" -O "$target_dir/$file"
    echo "download $file finished"
    echo
done
echo "Successfully download all the data-meta"
================================================
FILE: tools/download_dataset_native_resolution.sh
================================================
# Download and extract the preprocessed native-resolution ImageNet-1K latent shards.
set -e  # stop on a failed download/unzip instead of deleting a partial archive

target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution"
mkdir -p "$target_dir"
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-native-resolution"
files=(
    "n01440764_n01855672.zip"
    "n01860187_n02097298.zip"
    "n02097474_n02138441.zip"
    "n02165105_n02667093.zip"
    "n02669723_n03041632.zip"
    "n03042490_n03530642.zip"
    "n03532672_n03884397.zip"
    "n03887697_n04239074.zip"
    "n04243546_n04560804.zip"
    "n04562935_n15075141.zip"
)
for file in "${files[@]}"; do
    echo "download $file ..."
    # -c resumes a partially-downloaded archive.
    wget -c "$base_url/$file" -O "$target_dir/$file"
    echo "download $file finished"
    echo "start unzip $file ..."
    unzip "$target_dir/$file" -d "$target_dir"
    echo "unzip $file finished"
    # Remove the archive once extracted to save disk space.
    rm "$target_dir/$file"
    echo
done
# Fixed copy-pasted message: this script fetches dataset shards, not sampler-meta.
echo "Successfully downloaded all the dataset files"
================================================
FILE: tools/download_dataset_sampler_meta.sh
================================================
# Fetch the precomputed LPFHP packing indices (one JSON per max-seq-len
# budget) used by the packed samplers during training.
target_dir="datasets/imagenet1k/sampler_meta"
mkdir -p $target_dir
# All sampler-meta files live under this HuggingFace dataset repository.
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/sampler_meta"
files=(
    "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json"
    "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json"
    "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json"
    "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json"
)
for file in "${files[@]}"; do
    echo "download $file ..."
    # wget -c resumes a partially-downloaded file on re-run.
    wget -c "$base_url/$file" -O "$target_dir/$file"
    echo "download $file finished"
    echo
done
echo "Successfully download all the sampler-meta"
================================================
FILE: tools/pack_dataset.py
================================================
import argparse
import json
import os

from nit.data.pack import pack_dataset
def create_pack(data_meta, max_seq_len, algorithm, split):
    """Pack a dataset's sequences into token-budgeted packs and dump the indices.

    Args:
        data_meta: Path to a JSONL metadata file; each line must contain
            ``latent_h`` and ``latent_w`` fields.
        max_seq_len: Token budget of one pack (patch_size=1, so one latent
            cell counts as one token).
        algorithm: Packing algorithm name passed to ``pack_dataset``
            (e.g. ``'LPFHP'``).
        split: Number of chunks the dataset is partitioned into; each
            chunk is packed independently.

    Writes ``datasets/imagenet1k/sampler_meta/<meta-name>_<algorithm>_<max_seq_len>.json``.
    """
    # A pack can hold at most as many sequences as it has tokens.
    max_seq_per_pack = max_seq_len
    with open(data_meta, 'r') as fp:
        ori_dataset = [json.loads(line) for line in fp]
    # Sequence length of an item is its latent area (patch_size=1).
    dataset_seq_lens = [int(d['latent_h'] * d['latent_w']) for d in ori_dataset]
    dataset_seq_idxs = list(range(len(ori_dataset)))
    total_length = len(ori_dataset)
    # Ceil-divide so the trailing remainder is still packed. The previous
    # floor division silently dropped up to split-1 items whenever
    # total_length was not a multiple of split.
    run_length = (total_length + split - 1) // split
    all_packed_indices = []
    for i in range(split):
        seq_lens = dataset_seq_lens[i * run_length: (i + 1) * run_length]
        seq_idxs = dataset_seq_idxs[i * run_length: (i + 1) * run_length]
        if not seq_lens:
            # split > total_length can leave empty tail chunks; skip them.
            continue
        packed_indices = pack_dataset(
            algorithm, max_seq_len, max_seq_per_pack, seq_lens, seq_idxs
        )
        all_packed_indices.extend(packed_indices)
    # Derive the output name from the metadata file name.
    sampler_json_name = data_meta.split('/')[-1].replace('_meta.jsonl', '')
    sampler_json_name = f"{sampler_json_name}_{algorithm}_{max_seq_len}.json"
    out_dir = 'datasets/imagenet1k/sampler_meta'
    # Create the output directory so a fresh checkout does not crash here.
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/{sampler_json_name}', 'w') as fp:
        json.dump(all_packed_indices, fp, indent=4)
if __name__ == '__main__':
    # CLI entry point: pack one metadata file into sampler packs.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-meta",
        type=str,
        default='datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl',
    )
    parser.add_argument("--max-seq-len", type=int, default=16384)
    parser.add_argument("--algorithm", type=str, default='LPFHP')
    parser.add_argument("--split", type=int, default=1)
    cli = parser.parse_args()
    create_pack(
        data_meta=cli.data_meta,
        max_seq_len=cli.max_seq_len,
        algorithm=cli.algorithm,
        split=cli.split,
    )