Repository: WZDTHU/NiT Branch: main Commit: 02fa292a7cd1 Files: 91 Total size: 565.6 KB Directory structure: gitextract_jijj82e1/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── c2i/ │ │ ├── nit_b_pack_merge_radio_65536.yaml │ │ ├── nit_l_pack_merge_radio_16384.yaml │ │ ├── nit_s_pack_merge_radio_65536.yaml │ │ ├── nit_xl_pack_merge_radio_16384.yaml │ │ └── nit_xxl_pack_merge_radio_8192.yaml │ └── preprocess/ │ ├── imagenet1k_256x256.yaml │ ├── imagenet1k_512x512.yaml │ └── imagenet1k_native_resolution.yaml ├── nit/ │ ├── data/ │ │ ├── pack/ │ │ │ ├── __init__.py │ │ │ ├── ennlshp.py │ │ │ ├── lpfhp.py │ │ │ ├── nnlshp.py │ │ │ └── spfhp.py │ │ ├── packed_c2i_data.py │ │ └── sampler_util.py │ ├── models/ │ │ ├── c2i/ │ │ │ └── nit_model.py │ │ ├── nvidia_radio/ │ │ │ ├── hubconf.py │ │ │ └── radio/ │ │ │ ├── __init__.py │ │ │ ├── adaptor_base.py │ │ │ ├── adaptor_generic.py │ │ │ ├── adaptor_mlp.py │ │ │ ├── adaptor_registry.py │ │ │ ├── block.py │ │ │ ├── cls_token.py │ │ │ ├── common.py │ │ │ ├── conv.py │ │ │ ├── dinov2_arch.py │ │ │ ├── dual_hybrid_vit.py │ │ │ ├── enable_cpe_support.py │ │ │ ├── enable_damp.py │ │ │ ├── enable_spectral_reparam.py │ │ │ ├── eradio_model.py │ │ │ ├── extra_models.py │ │ │ ├── extra_timm_models.py │ │ │ ├── feature_normalizer.py │ │ │ ├── forward_intermediates.py │ │ │ ├── hf_model.py │ │ │ ├── input_conditioner.py │ │ │ ├── open_clip_adaptor.py │ │ │ ├── radio_model.py │ │ │ ├── vision_transformer_xpos.py │ │ │ ├── vit_patch_generator.py │ │ │ └── vitdet.py │ │ └── utils/ │ │ ├── convs.py │ │ ├── funcs.py │ │ ├── norms.py │ │ └── pos_embeds/ │ │ ├── flash_attn_rotary.py │ │ ├── rope.py │ │ └── sincos.py │ ├── schedulers/ │ │ └── flow_matching/ │ │ ├── loss.py │ │ └── samplers_c2i.py │ └── utils/ │ ├── __init__.py │ ├── deepspeed_zero_to_fp32.py │ ├── ema.py │ ├── eval_utils.py │ ├── freeze.py │ ├── gpu_memory_monitor.py │ ├── lr_scheduler.py │ ├── misc_utils.py │ ├── model_utils.py │ ├── train_utils.py │ ├── util.py 
│ ├── video_utils.py │ └── warp_pos_idx.py ├── projects/ │ ├── evaluate/ │ │ └── adm_evaluator.py │ ├── preprocess/ │ │ ├── image_latent_c2i.py │ │ └── image_nr_latent_c2i.py │ ├── sample/ │ │ └── sample_c2i_ddp.py │ └── train/ │ └── packed_trainer_c2i.py ├── requirements.txt ├── scripts/ │ ├── preprocess/ │ │ ├── preorocess_in1k_256x256.sh │ │ ├── preorocess_in1k_512x512.sh │ │ └── preorocess_in1k_native_resolution.sh │ ├── sample/ │ │ ├── sample_256x256.sh │ │ ├── sample_512x512.sh │ │ └── sample_768x768.sh │ └── train/ │ ├── train_b_model.sh │ ├── train_l_model.sh │ ├── train_s_model.sh │ ├── train_xl_model.sh │ └── train_xxl_model.sh ├── setup.py └── tools/ ├── download_dataset_256x256.sh ├── download_dataset_512x512.sh ├── download_dataset_data_meta.sh ├── download_dataset_native_resolution.sh ├── download_dataset_sampler_meta.sh └── pack_dataset.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. 
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ # Ruff stuff: .ruff_cache/ # PyPI configuration file .pypirc *.json *.png *.jpg /checkpoints /workdir /datasets /wandb /samples ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Native-Resolution Image Synthesis

ZiDong Wang1,2·Lei Bai2,*·Xiangyu Yue1·Wanli Ouyang1,2·Yiyuan Zhang1,2,* 1 MMLab CUHK   2Shanghai AI Lab
*Correspondence  

[project page]  [arXiv]  [Dataset]  [Model] 


Summary: We propose Native-resolution diffusion Transformer (NiT), a model that explicitly learns varying resolutions and aspect ratios within its denoising process. This significantly improves training efficiency and generalization capability. To the best of our knowledge, NiT is the first model to attain SOTA results on both the $256\times256$ ($2.08$ FID) and $512\times512$ ($1.48$ FID) benchmarks in class-guided ImageNet generation. NiT also generalizes to arbitrary resolutions and aspect ratios, achieving, for example, $4.52$ FID at $1024\times1024$ resolution and $4.11$ FID at $432\times768$ resolution. ![Figure](./assets/teaser.png) ### 🚨 News - `2025-9-18` NiT is accepted by NeurIPS 2025! 🍺 - `2025-6-3` We are delighted to introduce NiT, which is the first work to explicitly model native resolution image synthesis. We have released the code, pretrained models, and processed dataset of NiT. ### 1. Setup First, clone the repo: ```bash git clone https://github.com/WZDTHU/NiT.git && cd NiT ``` #### 1.1 Environment Setup ```bash conda create -n nit_env python=3.10 conda activate nit_env pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118 pip install flash-attn pip install -r requirements.txt pip install -e . ``` #### 1.2 Model Zoo (WIP) With a single model, NiT-XL can compete on multiple benchmarks, and it achieves a dual SOTA on both the ImageNet-$256\times256$ and $512\times512$ benchmarks. 
| Model | Model Zoo | Model Size | FID-256x256 | FID-512x512 | FID-768x768 | FID-1024x1024 | |---------------|------------|---------|------------|------------|------------|------------| | NiT-XL-1000K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors) | 675M | 2.16 | 1.57 | 4.05 | 4.52 | | NiT-XL-1500K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors) | 675M | 2.03 | 1.45 | - | - | ```bash mkdir checkpoints wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors" -O checkpoints/nit_xl_model_1000K.safetensors wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors" -O checkpoints/nit_xl_model_1500K.safetensors ``` ### 2. Sampling #### 2.1 Sampling Hyper-parameters The sampling hyper-parameters for NiT-XL-1000K are summarized as follows: | Resolution | Solver | NFE | CFG - scale | CFG - interval | FID | sFID | IS | Prec. | Rec. 
| | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | 256 × 256 | SDE | 250 | 2.25 | [0.0, 0.7] | 2.16 | 6.34 | 253.44 | 0.79 | 0.62 | | 512 × 512 | SDE | 250 | 2.05 | [0.0, 0.7] | 1.57 | 4.13 | 260.69 | 0.81 | 0.63 | | 768 × 768 | ODE | 50 | 3.0 | [0.0, 0.7] | 4.05 | 8.77 | 262.31 | 0.83 | 0.52 | | 1024 × 1024 | ODE | 50 | 3.0 | [0.0, 0.8] | 4.52 | 7.99 | 286.87 | 0.82 | 0.50 | | 1536 × 1536 | ODE | 50 | 3.5 | [0.0, 0.9] | 6.51 | 9.97 | 230.10 | 0.83 | 0.42 | | 2048 × 2048 | ODE | 50 | 4.5 | [0.0, 0.9] | 24.76 | 18.02 | 131.36 | 0.67 | 0.46 | | 320 × 960 | ODE | 50 | 4.0 | [0.0, 0.9] | 16.85 | 17.79 | 189.18 | 0.71 | 0.38 | | 432 × 768 | ODE | 50 | 2.75 | [0.0, 0.7] | 4.11 | 10.30 | 254.71 | 0.83 | 0.55 | | 480 × 640 | ODE | 50 | 2.75 | [0.0, 0.7] | 3.72 | 8.23 | 284.94 | 0.83 | 0.54 | | 640 × 480 | ODE | 50 | 2.5 | [0.0, 0.7] | 3.41 | 8.07 | 259.06 | 0.83 | 0.56 | | 768 × 432 | ODE | 50 | 2.85 | [0.0, 0.7] | 5.27 | 9.92 | 218.78 | 0.80 | 0.55 | | 960 × 320 | ODE | 50 | 4.5 | [0.0, 0.9] | 9.90 | 25.78 | 255.95 | 0.74 | 0.40 | #### 2.2 Sampling Scripts Sampling with the NiT-XL-1000K model for $256\times256$-resolution images: ```bash bash scripts/sample/sample_256x256.sh ``` Sampling with the NiT-XL-1000K model for $512\times512$-resolution images: ```bash bash scripts/sample/sample_512x512.sh ``` Sampling with the NiT-XL-1000K model for $768\times768$-resolution images: ```bash bash scripts/sample/sample_768x768.sh ``` ### 3. Evaluation Sampling produces a folder of images, which is used to compute FID, Inception Score and other metrics. Note that we do not pack the generated samples into a `.npz` file; this does not affect the calculation of FID or the other metrics. Please follow the [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations) to set up the conda environment and download the reference batch. 
```bash wget -c "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" -O checkpoints/classify_image_graph_def.pb ``` Given the directory of the reference batch `REFERENCE_DIR` and the directory of the generated images `SAMPLING_DIR`, run the following command: ```bash python projects/evaluate/adm_evaluator.py $REFERENCE_DIR $SAMPLING_DIR ``` ### 4. Training #### 4.1 Dataset Setup Currently, we provide the complete [preprocessed dataset](https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K) for ImageNet1K. Please use the following commands to download the meta files and preprocessed latents. ```bash mkdir datasets mkdir datasets/imagenet1k bash tools/download_dataset_256x256.sh bash tools/download_dataset_512x512.sh bash tools/download_dataset_native_resolution.sh ``` #### Preprocess ImageNet1K Locally You can also preprocess the ImageNet1K dataset on your own. Taking $256\times256$-image preprocessing as an example, first set `data_dir` to your local ImageNet1K directory in `configs/preprocess/imagenet1k_256x256.yaml`. Then run the preprocess script `scripts/preprocess/preorocess_in1k_256x256.sh`. ```bash bash scripts/preprocess/preorocess_in1k_256x256.sh ``` The preprocessing procedure for $512\times512$ images and native-resolution images is similar: modify the corresponding config file and run the script. ```bash bash scripts/preprocess/preorocess_in1k_512x512.sh bash scripts/preprocess/preorocess_in1k_native_resolution.sh ``` #### 4.2 Packing Because we pack multiple image instances with distinct resolutions into one sequence, we need to pre-set the image indices of each pack before training. #### Download meta-info files First, download all the data-meta files, which record the height, width and other information of each image. 
```bash bash tools/download_dataset_data_meta.sh ``` The above command will download four data-meta files into the `datasets/imagenet1k/data_meta` directory: - `dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl`: data-meta file for $256\times256$-resolution image data. - `dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl`, data-meta file for $512\times512$-resolution image data. - `dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`, data-meta file for native-resolution image data. - `dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl`, a merged file of the above three files. The first two entries of the native-resolution-image data-meta file (`dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`) are as follows: ```json {"image_file": "n01601694/n01601694_11629.JPEG", "latent_file": "n01601694/n01601694_11629.safetensors", "ori_w": 580, "ori_h": 403, "latent_h": 12, "latent_w": 18, "image_h": 384, "image_w": 576, "type": "native-resolution"} {"image_file": "n01601694/n01601694_11799.JPEG", "latent_file": "n01601694/n01601694_11799.safetensors", "ori_w": 500, "ori_h": 350, "latent_h": 10, "latent_w": 15, "image_h": 320, "image_w": 480, "type": "native-resolution"} ``` #### Sampler-Meta Download Given the maximum sequence length $L$, we pre-set the image indices of each pack before training. Here we use the LPFHP (longest-pack-first histogram packing) algorithm to pack the entire dataset. You can download our preprocessed packed sampler-meta files using the following command. ```bash bash tools/download_dataset_sampler_meta.sh ``` The above command will download four sampler-meta files into the `datasets/imagenet1k/sampler_meta` directory: - `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json`: corresponds to $L=8192$. - `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json`: corresponds to $L=16384$. This is the setting in NiT-XL experiments. - `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json`, corresponds to $L=32768$. 
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json`, corresponds to $L=65536$. #### Prepare the Packing (Sampler-Meta) on Your Own NiT supports training with images of arbitrary resolutions and aspect ratios, so you can also prepare the packing (sampler-meta) according to your own needs. ```bash # generate the default sampler-meta python tools/pack_dataset.py # generate the sampler-meta for the fixed 256x256-resolution experiment with a maximum sequence length of 16384 python tools/pack_dataset.py --data-meta datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl --max-seq-len 16384 ``` #### Download Image Encoder For the NiT-S (33M) model, we use RADIO-v2.5-B as the image encoder for the REPA loss. For other NiT models, we use RADIO-v2.5-H as our image encoder. ```bash wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar" -O checkpoints/radio_v2.5-h.pth.tar wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar" -O checkpoints/radio-v2.5-b_half.pth.tar ``` #### Training Scripts The above steps set up the `packed_json`, `jsonl_dir`, and `latent_dirs` in `configs/c2i/nit_xl_pack_merge_radio_16384.yaml`. Before training, please set `image_dir` to the ImageNet1K dataset directory on your own machine. 
To train the XL-model (675M): ```bash bash scripts/train/train_xl_model.sh ``` Specify the `image_dir` in `configs/c2i/nit_s_pack_merge_radio_65536.yaml` and train the small-model (33M): ```bash bash scripts/train/train_s_model.sh ``` Specify the `image_dir` in `configs/c2i/nit_b_pack_merge_radio_65536.yaml` and train the base-model (131M): ```bash bash scripts/train/train_b_model.sh ``` Specify the `image_dir` in `configs/c2i/nit_l_pack_merge_radio_16384.yaml` and train the large-model (457M): ```bash bash scripts/train/train_l_model.sh ``` Specify the `image_dir` in `configs/c2i/nit_xxl_pack_merge_radio_8192.yaml` and train the xxl-model (1.37B): ```bash bash scripts/train/train_xxl_model.sh ``` ### Citations If you find the project useful, please kindly cite: ```bibtex @article{wang2025native, title={Native-Resolution Image Synthesis}, author={Wang, Zidong and Bai, Lei and Yue, Xiangyu and Ouyang, Wanli and Zhang, Yiyuan}, year={2025}, eprint={2506.03131}, archivePrefix={arXiv}, primaryClass={cs.CV} } ``` ### License This project is licensed under the Apache-2.0 license. 
================================================ FILE: configs/c2i/nit_b_pack_merge_radio_65536.yaml ================================================ model: transport: path_type: linear prediction: v weighting: lognormal network: target: nit.models.c2i.nit_model.NiT params: class_dropout_prob: 0.1 num_classes: 1000 depth: 12 hidden_size: 768 patch_size: 1 in_channels: 32 num_heads: 12 qk_norm: True encoder_depth: 4 z_dim: 1280 use_checkpoint: False # pretrained_vae: vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers slice_vae: False tile_vae: False # repa encoder enc_type: radio enc_dir: checkpoints/radio_v2.5-h.pth.tar proj_coeff: 1.0 # ema use_ema: True ema_decay: 0.9999 data: data_type: improved_pack dataset: packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512'] latent_dirs: [ 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512', ] image_dir: /train dataloader: num_workers: 4 batch_size: 1 # Batch size (per device) for the training dataloader. 
training: tracker: null tracker_kwargs: {'wandb': {'group': 'c2i'}} max_train_steps: 2000000 checkpointing_steps: 2000 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 5.0e-5 learning_rate_base_batch_size: 1 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 0 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [200000, 500000, 100000, 150000] ================================================ FILE: configs/c2i/nit_l_pack_merge_radio_16384.yaml ================================================ model: transport: path_type: linear prediction: v weighting: lognormal network: target: nit.models.c2i.nit_model.NiT params: class_dropout_prob: 0.1 num_classes: 1000 depth: 24 hidden_size: 1024 patch_size: 1 in_channels: 32 num_heads: 16 qk_norm: True encoder_depth: 6 z_dim: 1280 use_checkpoint: False # pretrained_vae: vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers slice_vae: False tile_vae: False # repa encoder enc_type: radio enc_dir: checkpoints/radio_v2.5-h.pth.tar proj_coeff: 1.0 # ema use_ema: True ema_decay: 0.9999 data: data_type: improved_pack dataset: packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512'] latent_dirs: [ 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512', ] image_dir: /train dataloader: num_workers: 4 batch_size: 1 # Batch size 
(per device) for the training dataloader. training: tracker: null tracker_kwargs: {'wandb': {'group': 'c2i'}} max_train_steps: 2000000 checkpointing_steps: 2000 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 5.0e-5 learning_rate_base_batch_size: 4 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 0 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [200000, 500000, 100000, 150000] ================================================ FILE: configs/c2i/nit_s_pack_merge_radio_65536.yaml ================================================ model: transport: path_type: linear prediction: v weighting: lognormal network: target: nit.models.c2i.nit_model.NiT params: class_dropout_prob: 0.1 num_classes: 1000 depth: 12 hidden_size: 384 patch_size: 1 in_channels: 32 num_heads: 6 qk_norm: True encoder_depth: 4 z_dim: 768 projector_dim: 768 use_checkpoint: False # pretrained_vae: vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers slice_vae: False tile_vae: False # repa encoder enc_type: radio enc_dir: checkpoints/radio-v2.5-b_half.pth.tar proj_coeff: 1.0 # ema use_ema: True ema_decay: 0.9999 data: data_type: improved_pack dataset: packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512'] latent_dirs: [ 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512', ] 
image_dir: /train dataloader: num_workers: 4 batch_size: 1 # Batch size (per device) for the training dataloader. training: tracker: null tracker_kwargs: {'wandb': {'group': 'c2i'}} max_train_steps: 2000000 checkpointing_steps: 2000 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 5.0e-5 learning_rate_base_batch_size: 1 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 0 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [200000, 500000, 100000, 150000] ================================================ FILE: configs/c2i/nit_xl_pack_merge_radio_16384.yaml ================================================ model: transport: path_type: linear prediction: v weighting: lognormal network: target: nit.models.c2i.nit_model.NiT params: class_dropout_prob: 0.1 num_classes: 1000 depth: 28 hidden_size: 1152 patch_size: 1 in_channels: 32 num_heads: 16 qk_norm: True encoder_depth: 8 z_dim: 1280 use_checkpoint: False # pretrained_vae: vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers slice_vae: False tile_vae: False # repa encoder enc_type: radio enc_dir: checkpoints/radio_v2.5-h.pth.tar proj_coeff: 1.0 # ema use_ema: True ema_decay: 0.9999 data: data_type: improved_pack dataset: packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512'] latent_dirs: [ 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256', 
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512', ] image_dir: /train dataloader: num_workers: 4 batch_size: 1 # Batch size (per device) for the training dataloader. training: tracker: null tracker_kwargs: {'wandb': {'group': 'c2i'}} max_train_steps: 2000000 checkpointing_steps: 2000 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 5.0e-5 learning_rate_base_batch_size: 4 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 0 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [200000, 500000, 100000, 150000] ================================================ FILE: configs/c2i/nit_xxl_pack_merge_radio_8192.yaml ================================================ model: transport: path_type: linear prediction: v weighting: lognormal network: target: nit.models.c2i.nit_model.NiT params: class_dropout_prob: 0.1 num_classes: 1000 depth: 40 hidden_size: 1536 patch_size: 1 in_channels: 32 num_heads: 24 qk_norm: True encoder_depth: 8 z_dim: 1280 use_checkpoint: False use_adaln_lora: True adaln_lora_dim: 512 # pretrained_vae: vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers slice_vae: False tile_vae: False # repa encoder enc_type: radio enc_dir: checkpoints/radio_v2.5-h.pth.tar proj_coeff: 1.0 # ema use_ema: True ema_decay: 0.9999 data: data_type: improved_pack dataset: packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512'] latent_dirs: [ 
'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256', 'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512', ] image_dir: /train dataloader: num_workers: 4 batch_size: 1 # Batch size (per device) for the training dataloader. training: tracker: null tracker_kwargs: {'wandb': {'group': 'c2i'}} max_train_steps: 1000000 checkpointing_steps: 2000 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 5.0e-5 learning_rate_base_batch_size: 4 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 0 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [200000, 500000, 100000] ================================================ FILE: configs/preprocess/imagenet1k_256x256.yaml ================================================ model: vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers data: dataset: data_dir: /train target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256 resolution: 256 dataloader: num_workers: 1 batch_size: 64 # Batch size (per device) for the training dataloader. 
training: tracker: null tracker_kwargs: {'wandb': {'group': 't2i'}} max_train_steps: 100000 checkpointing_steps: 200 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 1.0e-4 learning_rate_base_batch_size: 256 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 4000 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [20000, 40000, 60000, 80000] ================================================ FILE: configs/preprocess/imagenet1k_512x512.yaml ================================================ model: vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers data: dataset: data_dir: /train target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512 resolution: 512 dataloader: num_workers: 1 batch_size: 16 # Batch size (per device) for the training dataloader. 
training: tracker: null tracker_kwargs: {'wandb': {'group': 't2i'}} max_train_steps: 100000 checkpointing_steps: 200 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 1.0e-4 learning_rate_base_batch_size: 256 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 4000 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [20000, 40000, 60000, 80000] ================================================ FILE: configs/preprocess/imagenet1k_native_resolution.yaml ================================================ model: vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers data: dataset: data_dir: /train target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution min_image_size: 32 max_image_size: 2048 dataloader: num_workers: 1 batch_size: 1 # Batch size (per device) for the training dataloader. 
training: tracker: null tracker_kwargs: {'wandb': {'group': 't2i'}} max_train_steps: 100000 checkpointing_steps: 200 checkpoints_total_limit: 2 resume_from_checkpoint: latest learning_rate: 1.0e-4 learning_rate_base_batch_size: 256 scale_lr: True lr_scheduler: constant # "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"] lr_warmup_steps: 4000 gradient_accumulation_steps: 1 optimizer: target: torch.optim.AdamW params: # betas: ${tuple:0.9, 0.999} betas: [0.9, 0.95] weight_decay: 1.0e-2 eps: 1.0e-6 max_grad_norm: 1.0 proportion_empty_prompts: 0.0 mixed_precision: bf16 # ["no", "fp16", "bf16"] allow_tf32: True validation_steps: 500 checkpoint_list: [20000, 40000, 60000, 80000] ================================================ FILE: nit/data/pack/__init__.py ================================================ from .ennlshp import ENNLSHP from .lpfhp import LPFHP from .nnlshp import NNLSHP from .spfhp import SPFHP import json import torch import numpy as np from tqdm import tqdm def get_strategy(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens): def generate_histogram(dataset_seq_lens): histogram = np.zeros(max_seq_len, dtype=np.int64) seq_lens, counts = np.unique(np.array(dataset_seq_lens), return_counts=True) histogram[seq_lens - 1] = counts return histogram histogram = generate_histogram(dataset_seq_lens) if algorithm == "SPFHP": strategy = SPFHP(histogram, max_seq_len, max_seq_per_pack) elif algorithm == "LPFHP": strategy = LPFHP(histogram, max_seq_len, max_seq_per_pack) elif algorithm == 'ENNLSHP': strategy = ENNLSHP(histogram, max_seq_len, max_seq_per_pack) elif algorithm == 'NNLSHP': strategy = NNLSHP(histogram, max_seq_len, max_seq_per_pack) else: raise NotImplementedError("Algorithm type unsupported. 
Pass one of: LPFHP, SPFHP") return strategy def pack_dataset(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens, dataset_seq_idxs): dataset_seqs = torch.stack([torch.tensor(dataset_seq_lens), torch.tensor(dataset_seq_idxs)]) strategy_set, strategy_repeat_count = get_strategy( algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens ) packed_indices = [] run_iters = sum(strategy_repeat_count) progress_bar = tqdm(range(run_iters)) for i in range(len(strategy_repeat_count)): strategy = strategy_set[i] for _ in range(strategy_repeat_count[i]): progress_bar.update(1) ref_inds = [] for x in strategy: ref_ind = torch.argwhere(dataset_seqs[0] == x)[-1] dataset_seqs[0, ref_ind] = -1 ref_inds.append(ref_ind) inds = dataset_seqs[1, ref_inds].ravel() packed_indices.append(inds.tolist()) return packed_indices ================================================ FILE: nit/data/pack/ennlshp.py ================================================ # Copyright (c) 2021 Graphcore Ltd. All rights reserved. 
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/ennlshp.py """Extended Non-Negative least squares histogram-packing.""" import time import numpy as np from scipy import optimize, stats from functools import lru_cache def get_packing_matrix(strategy_set, max_sequence_length): num_strategies = len(strategy_set) A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32) for i, strategy in enumerate(strategy_set): for seq_len in strategy: A[seq_len - 1, i] += 1 return A @lru_cache(maxsize=None) def get_packing_strategies(start_length, minimum_increment, target_length, depth): gap = target_length - start_length strategies = [] # Complete the packing with exactly 1 number if depth == 1: if gap >= minimum_increment: strategies.append([gap]) # Complete the sample in "depth" steps, recursively else: for new in range(minimum_increment, gap + 1): new_gap = target_length - start_length - new if new_gap == 0: strategies.append([new]) else: options = get_packing_strategies(start_length + new, new, target_length, depth - 1) for option in options: if len(option) > 0: strategies.append([new] + option) return strategies def ENNLSHP(histogram, max_sequence_length, max_sequences_per_pack): # List all unique ways of packing to the desired maximum sequence length strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack) # Get the packing matrix corresponding to this list of packing strategies A = get_packing_matrix(strategy_set, max_sequence_length) # Weights that penalize the residual by the number of resulting padding tokens. 
w0 = np.array([x + 1 for x in range(max_sequence_length)]) # construct the packing matrix A_bar = np.zeros((2 * max_sequence_length, len(strategy_set) + max_sequence_length), "d") # Base weighted matrix A_bar[:max_sequence_length, : len(strategy_set)] = np.expand_dims(w0, -1) * A # Higher weight to avoid positive residual A_bar[max_sequence_length:, : len(strategy_set)] = np.expand_dims(10**6 * np.ones([max_sequence_length]), -1) * A # negative diagonal unity matrix for mapping to residual A_bar[max_sequence_length:, len(strategy_set) :] = np.expand_dims( 10**6 * np.ones([max_sequence_length]), -1 ) * np.ones((max_sequence_length, max_sequence_length)) b_bar = np.zeros(2 * max_sequence_length) # Apply weighting to histogram vector b_bar[:max_sequence_length] = w0 * histogram b_bar[max_sequence_length:] = 10**6 * np.ones([max_sequence_length]) * histogram # Solve the packing problem start = time.time() strategy_residual, rnorm = optimize.nnls(A_bar, b_bar) strategy_repeat_count = strategy_residual[: len(strategy_set)] # Round the floating point solution to nearest integer strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64) # Compute the residuals, shape: [max_sequence_length] residual = histogram - A @ strategy_repeat_count # Handle the left-over sequences; that is the positive part of residual unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0] for l in unpacked_seqlen: strategy = sorted([l, max_sequence_length - l]) # the depth 1 strategy strategy_index = strategy_set.index(strategy) strategy_repeat_count[strategy_index] += residual[l - 1] # Re-compute the residual with the updated strategy_repeat_count # This should now be strictly < 0 residual = histogram - A @ strategy_repeat_count # Add padding based on deficit (negative residual portion of residual) padding = np.where(residual < 0, -residual, 0) # Calculate some basic statistics duration = time.time() - start sequence_lengths = np.arange(1, max_sequence_length + 1) 
old_number_of_samples = histogram.sum() new_number_of_samples = int(strategy_repeat_count.sum()) speedup_upper_bound = 1.0 / ( 1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples ) num_padding_tokens_packed = (sequence_lengths * padding).sum() efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length) print( f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n", f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n", f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n" f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.", ) return strategy_set, strategy_repeat_count ================================================ FILE: nit/data/pack/lpfhp.py ================================================ # Copyright (c) 2021 Graphcore Ltd. All rights reserved. # modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/lpfhp.py """Longest-pack-first histogram-packing.""" from collections import defaultdict import numpy as np import time def add_pack(pack, count, tmp, final, limit, offset, max_sequence_length=512): """Filter out packs that reached maximum length or number of components.""" # sanity checks assert max_sequence_length - sum(pack) == offset, "Incorrect offset." assert offset >= 0, "Too small offset." assert offset < max_sequence_length, "Too large offset." if len(pack) == limit or offset == 0: final[offset].append((count, pack)) else: tmp[offset].append((count, pack)) def LPFHP(histogram, max_sequence_length, max_sequences_per_pack, distribute=True): """Longest-pack-first histogram-packing.""" start = time.time() reversed_histogram = np.flip(histogram) # Initialize main strategy data dictionary. # The key indicates how many tokens are left for full length. # The value is a list of tuples, consisting of counts and respective packs. 
# A pack is a (sorted) list of sequence length values that get concatenated. tmp_strategies_per_length = defaultdict(list) strategies_per_length = defaultdict(list) if max_sequences_per_pack == "max": max_sequences_per_pack = max_sequence_length # Index i indicates here, how much space is left, due to reversed histogram for i in range(max_sequence_length): n_sequences_to_bin = reversed_histogram[i] length_to_bin = max_sequence_length - i offset = 0 # smallest possible offset for perfect fit while n_sequences_to_bin > 0: if (length_to_bin + offset) in tmp_strategies_per_length: # extract worst pack that will get modified n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop() # calculate how often the current sequence maximally fits in repeat = min(1 + offset // length_to_bin, max_sequences_per_pack - len(pack)) # correct dependent on count while n_sequences_to_bin // repeat == 0: repeat -= 1 if not distribute: repeat = 1 new_pack = pack + [length_to_bin] * repeat count = min(n_sequences_to_pack, n_sequences_to_bin // repeat) if n_sequences_to_pack > count: # old pack gets reduced n_sequences_to_pack -= count tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack)) n_sequences_to_bin -= count * repeat else: n_sequences_to_bin -= n_sequences_to_pack * repeat add_pack( new_pack, count, tmp_strategies_per_length, strategies_per_length, max_sequences_per_pack, offset - (repeat - 1) * length_to_bin, max_sequence_length, ) # clean up to speed up main key search if not tmp_strategies_per_length[length_to_bin + offset]: tmp_strategies_per_length.pop(length_to_bin + offset) # reset offset in case best fit changed offset = 0 else: offset += 1 # Does not fit anywhere. Create new pack. if offset >= max_sequence_length - length_to_bin + 1: # similar repetition but no dependence on pack. 
repeat = min(max_sequence_length // length_to_bin, max_sequences_per_pack) while n_sequences_to_bin // repeat == 0: repeat -= 1 if not distribute: repeat = 1 add_pack( [length_to_bin] * repeat, n_sequences_to_bin // repeat, tmp_strategies_per_length, strategies_per_length, max_sequences_per_pack, max_sequence_length - length_to_bin * repeat, max_sequence_length, ) n_sequences_to_bin -= n_sequences_to_bin // repeat * repeat # merge all strategies for key in tmp_strategies_per_length: strategies_per_length[key].extend(tmp_strategies_per_length[key]) # flatten strategies dictionary strategy_set = [] strategy_repeat_count = [] for key in strategies_per_length: for count, pack in strategies_per_length[key]: pack.reverse() strategy_set.append(pack) strategy_repeat_count.append(count) # Summarize efficiency of solution duration = time.time() - start sequence_lengths = np.arange(1, max_sequence_length + 1) strategy_repeat_count = np.array(strategy_repeat_count) n_strategies = len(strategy_set) old_number_of_samples = histogram.sum() new_number_of_samples = strategy_repeat_count.sum() sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)]) total_tokens = max_sequence_length * new_number_of_samples empty_tokens = sum( [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)] ) efficiency = 100 - empty_tokens / total_tokens * 100 speedup_upper_bound = 1.0 / ( 1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples ) print( f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n", f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n", f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}", f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.", ) return strategy_set, strategy_repeat_count ================================================ FILE: nit/data/pack/nnlshp.py 
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/nnlshp.py
"""Non-Negative least squares histogram-packing."""
import time
import numpy as np
from scipy import optimize, stats
from functools import lru_cache


def get_packing_matrix(strategy_set, max_sequence_length):
    """Build the count matrix A of shape (max_sequence_length, len(strategy_set)).

    ``A[l - 1, i]`` is the number of sequences of length ``l`` in strategy ``i``.
    """
    num_strategies = len(strategy_set)
    A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32)
    for i, strategy in enumerate(strategy_set):
        for seq_len in strategy:
            A[seq_len - 1, i] += 1
    return A


@lru_cache(maxsize=None)
def get_packing_strategies(start_length, minimum_increment, target_length, depth):
    """Enumerate non-decreasing lists of lengths that sum to ``target_length - start_length``.

    Each list has at most ``depth`` entries, and every entry is >= ``minimum_increment``.
    The result is memoised by ``lru_cache``; callers must not mutate it.
    """
    gap = target_length - start_length
    strategies = []
    # Complete the packing with exactly 1 number
    if depth == 1:
        if gap >= minimum_increment:
            strategies.append([gap])
    # Complete the sample in "depth" steps, recursively
    else:
        for new in range(minimum_increment, gap + 1):
            new_gap = target_length - start_length - new
            if new_gap == 0:
                strategies.append([new])
            else:
                # `new` becomes the next minimum, so every strategy is sorted ascending.
                options = get_packing_strategies(start_length + new, new, target_length, depth - 1)
                for option in options:
                    if len(option) > 0:
                        strategies.append([new] + option)
    return strategies


def NNLSHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Non-negative least-squares histogram packing.

    Args:
        histogram: ``histogram[l - 1]`` is the number of sequences of length ``l``.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum sequences per pack (the enumeration depth).

    Returns:
        ``(strategy_set, strategy_repeat_count)``. Every strategy sums exactly to
        ``max_sequence_length``. ``strategy_repeat_count`` is an int64 array; slots
        it allocates beyond the histogram counts are padding.
    """
    # List all unique ways of packing to the desired maximum sequence length
    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)
    # Get the packing matrix corresponding to this list of packing strategies
    A = get_packing_matrix(strategy_set, max_sequence_length)
    # Weights that penalize the residual on short sequences less.
    penalization_cutoff = 8
    w0 = np.ones([max_sequence_length])
    w0[:penalization_cutoff] = 0.09
    # Solve the packing problem
    start = time.time()
    strategy_repeat_count, rnorm = optimize.nnls(np.expand_dims(w0, -1) * A, w0 * histogram)
    # Round the floating point solution to nearest integer
    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)
    # Compute the residuals, shape: [max_sequence_length]
    residual = histogram - A @ strategy_repeat_count
    # Handle the left-over sequences, that is the positive part of residual
    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]
    for l in unpacked_seqlen:
        # Two-sequence strategy [l, max_sequence_length - l] (not "depth 1"); the
        # partner slot becomes padding. NOTE(review): index() would raise
        # ValueError for l == max_sequence_length ([0, L] is never enumerated)
        # or when max_sequences_per_pack < 2.
        strategy = sorted([l, max_sequence_length - l])  # the depth 1 strategy
        strategy_index = strategy_set.index(strategy)
        strategy_repeat_count[strategy_index] += residual[l - 1]
    # Re-compute the residual with the updated strategy_repeat_count
    # This should now be strictly < 0
    # (more precisely: no entry is positive any more, i.e. residual <= 0)
    residual = histogram - A @ strategy_repeat_count
    # Add padding based on deficit (negative residual portion of residual)
    padding = np.where(residual < 0, -residual, 0)
    # Calculate some basic statistics
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = int(strategy_repeat_count.sum())
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    num_padding_tokens_packed = (sequence_lengths * padding).sum()
    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, strategy_repeat_count
================================================
FILE:
nit/data/pack/spfhp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/spfhp.py
"""Shortest-pack-first histogram-packing."""
from collections import defaultdict
import numpy as np
import time


def add_pack(pack, count, tmp, final, limit, offset):
    """Filter out packs that reached maximum length or number of sequences.

    Packs with no free tokens (``offset == 0``) or ``limit`` sequences go to
    ``final``; the rest stay in ``tmp`` for later sequences. Both dicts are
    keyed by ``offset``, the number of free tokens left in the pack.
    """
    if len(pack) == limit or offset == 0:
        final[offset].append((count, pack))
    else:
        tmp[offset].append((count, pack))


def SPFHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Shortest-pack-first histogram-packing.

    Args:
        histogram: ``histogram[l - 1]`` is the number of sequences of length ``l``.
        max_sequence_length: token capacity of one pack.
        max_sequences_per_pack: maximum number of sequences per pack.

    Returns:
        ``(strategy_set, strategy_repeat_count)``: packs as lists of lengths
        (reversed from insertion order, hence ascending), and a numpy array of
        how many packs use each.
    """
    start = time.time()
    reversed_histogram = np.flip(histogram)
    # Initialize main strategy data dictionary.
    # The key indicates how many tokens are left for full length.
    # The value is a list of tuples, consisting of counts and respective packs.
    # A pack is a (sorted) list of sequence length values that get concatenated.
    tmp_strategies_per_length = defaultdict(list)
    strategies_per_length = defaultdict(list)
    # Index i indicates here, how much space is left, due to reversed histogram
    for i in range(max_sequence_length):
        n_sequences_to_bin = reversed_histogram[i]
        length_to_bin = max_sequence_length - i
        offset = i + 1  # largest possible offset
        while n_sequences_to_bin > 0:
            if (length_to_bin + offset) in tmp_strategies_per_length:
                # extract shortest pack that will get modified
                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()
                new_pack = pack + [length_to_bin]
                count = min(n_sequences_to_pack, n_sequences_to_bin)
                if n_sequences_to_pack > n_sequences_to_bin:
                    # old pack gets reduced
                    n_sequences_to_pack -= n_sequences_to_bin
                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))
                    n_sequences_to_bin = 0
                else:
                    n_sequences_to_bin -= n_sequences_to_pack
                # After adding length_to_bin, the extended pack has `offset` tokens free.
                add_pack(
                    new_pack,
                    count,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    offset,
                )
                # clean up to speed up main key search
                if not tmp_strategies_per_length[length_to_bin + offset]:
                    tmp_strategies_per_length.pop(length_to_bin + offset)
            else:
                offset -= 1
            # Does not fit anywhere. Create new pack.
            if offset < 0:
                # A fresh pack holding one sequence of length_to_bin has i tokens free.
                add_pack(
                    [length_to_bin],
                    n_sequences_to_bin,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    i,
                )
                n_sequences_to_bin = 0
    # merge all strategies
    for key in tmp_strategies_per_length:
        strategies_per_length[key].extend(tmp_strategies_per_length[key])
    # flatten strategies dictionary
    strategy_set = []
    strategy_repeat_count = []
    for key in strategies_per_length:
        for count, pack in strategies_per_length[key]:
            pack.reverse()
            strategy_set.append(pack)
            strategy_repeat_count.append(count)

    # Summarize efficiency of solution
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    strategy_repeat_count = np.array(strategy_repeat_count)
    n_strategies = len(strategy_set)  # unused
    old_number_of_samples = histogram.sum()
    new_number_of_samples = strategy_repeat_count.sum()
    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])  # unused
    total_tokens = max_sequence_length * new_number_of_samples
    empty_tokens = sum(
        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]
    )
    efficiency = 100 - empty_tokens / total_tokens * 100
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n",
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    # strategy_repeat_count is already an ndarray here; np.array() just copies it.
    return strategy_set, np.array(strategy_repeat_count)
================================================
FILE: nit/data/packed_c2i_data.py
================================================
import os
import datetime
import torchvision
import numpy as np
import torch
import ast
import json
import time
from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from
torch.utils.data import DataLoader, Dataset from torchvision.datasets import ImageFolder from torchvision import transforms from accelerate.logging import get_logger from safetensors.torch import load_file from einops import rearrange from functools import partial from torchvision.transforms.functional import hflip from .sampler_util import get_train_sampler, get_packed_batch_sampler logger = get_logger(__name__, log_level="INFO") PATCH_SIZE = 1 def resize_arr(pil_image, height, width): pil_image = pil_image.resize((width, height), resample=Image.Resampling.BICUBIC) return pil_image def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) def packed_collate_fn(batch): packed_latent = [] label = [] hw_list = [] image = [] for data in batch: C, H, W = data['latent'].shape latent = rearrange( data['latent'], 'c (h p1) (w p2) -> (h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE ) packed_latent.append(latent) hw_list.append([H/PATCH_SIZE, W/PATCH_SIZE]) label.append(data['label']) image.append(data['image']) packed_latent = torch.concat(packed_latent) label = torch.tensor(label) hw_list = torch.tensor(hw_list, dtype=torch.int32) return dict(image=image, latent=packed_latent, label=label, hw_list=hw_list) class ImprovedPackedImageNetLatentDataset(Dataset): def __init__(self, packed_json, jsonl_dir, data_types, latent_dirs, 
image_dir): super().__init__() assert len(data_types) == len(latent_dirs) self.type_to_dir = dict() for i, data_type in enumerate(data_types): self.type_to_dir[data_type] = latent_dirs[i] self.image_dir = image_dir with open(packed_json, 'r') as fp: self.packed_dataset = json.load(fp) with open(jsonl_dir, 'r') as fp: self.dataset = [json.loads(line) for line in fp] def __len__(self): return len(self.dataset) def __getitem__(self, index): data_meta = self.dataset[index] data_item = dict() data_type = data_meta['type'] latent_file = os.path.join(self.type_to_dir[data_type], data_meta['latent_file']) image_file = os.path.join(self.image_dir, data_meta['image_file']) data = load_file(latent_file) height = data_meta['latent_h'] * 16 width = data_meta['latent_w'] * 16 if data_type == 'native-resolution': preprocess = partial(resize_arr, height=height, width=width) else: assert height == width preprocess = partial(center_crop_arr, image_size=height) transform = transforms.Compose([ transforms.Lambda(lambda pil_image: preprocess(pil_image=pil_image)), transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))), transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ]) rand_idx = torch.randint(low=0, high=2, size=(1,)).item() data_item['image'] = transform(Image.open(image_file).convert("RGB"))[rand_idx] data_item['latent'] = data['latent'][rand_idx] data_item['label'] = data['label'] return data_item class C2ILoader(): def __init__(self, data_config): super().__init__() self.batch_size = data_config.dataloader.batch_size self.num_workers = data_config.dataloader.num_workers self.data_type = data_config.data_type if data_config.data_type == 'improved_pack': self.train_dataset = ImprovedPackedImageNetLatentDataset( **OmegaConf.to_container(data_config.dataset) ) else: raise NotImplementedError self.test_dataset = None self.val_dataset = 
None def train_len(self): return len(self.train_dataset) def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed): sampler = get_train_sampler( self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed ) if self.data_type == 'improved_pack': batch_sampler = get_packed_batch_sampler( self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed ) return DataLoader( self.train_dataset, batch_sampler=batch_sampler, collate_fn=packed_collate_fn, num_workers=self.num_workers, pin_memory=True, ) else: return DataLoader( self.train_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers, pin_memory=True, drop_last=True, ) def test_dataloader(self): return None def val_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, pin_memory=True, drop_last=True ) ================================================ FILE: nit/data/sampler_util.py ================================================ import torch import json # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60 def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps, resume_step, seed): sample_indices = torch.empty([max_steps * global_batch_size // world_size], dtype=torch.long) epoch_id, fill_ptr, offs = 0, 0, 0 while fill_ptr < sample_indices.size(0): g = torch.Generator() g.manual_seed(seed + epoch_id) epoch_sample_indices = torch.randperm(len(dataset), generator=g) epoch_id += 1 epoch_sample_indices = epoch_sample_indices[ (rank + offs) % world_size::world_size ] offs = (offs + world_size - len(dataset) % world_size) % world_size epoch_sample_indices = epoch_sample_indices[ :sample_indices.size(0) - fill_ptr ] sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \ epoch_sample_indices fill_ptr += epoch_sample_indices.size(0) return 
sample_indices[resume_step * global_batch_size // world_size:].tolist() def get_packed_batch_sampler( dataset, rank, world_size, max_steps, resume_step, seed ): sample_indices = [None for _ in range(max_steps)] epoch_id, fill_ptr, offs = 0, 0, 0 while fill_ptr < len(sample_indices): g = torch.Generator() g.manual_seed(seed + epoch_id) epoch_sample_indices = torch.randperm(len(dataset), generator=g) epoch_id += 1 epoch_sample_indices = epoch_sample_indices[ (rank + offs) % world_size::world_size ] offs = (offs + world_size - len(dataset) % world_size) % world_size epoch_sample_indices = epoch_sample_indices[ :len(sample_indices) - fill_ptr ] sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = [ dataset[i] for i in epoch_sample_indices ] fill_ptr += epoch_sample_indices.size(0) return sample_indices[resume_step:] ================================================ FILE: nit/models/c2i/nit_model.py ================================================ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Mlp
from einops import rearrange, repeat
from flash_attn import flash_attn_varlen_func
from nit.models.utils.funcs import get_parameter_dtype
from nit.models.utils.pos_embeds.rope import VisionRotaryEmbedding, rotate_half
from typing import Optional


def modulate(x, shift, scale):
    # adaLN modulation: scale around 1 (so a zero `scale` is the identity) and then shift.
    return x * (1 + scale) + shift


def build_mlp(hidden_size, projector_dim, z_dim):
    """Three-layer SiLU MLP: hidden_size -> projector_dim -> projector_dim -> z_dim.

    Used by NiT as the projector applied to intermediate block features (``zs``).
    """
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations: a sinusoidal embedding of size
    ``frequency_embedding_size`` followed by a Linear -> SiLU -> Linear MLP to ``hidden_size``.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def positional_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) float32 Tensor laid out as [cos | sin]; when ``dim`` is odd
                 a single zero column is appended.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        # NOTE(review): re-binds an alias of the static method on every call; it has no
        # effect beyond creating the `timestep_embedding` attribute.
        self.timestep_embedding = self.positional_embedding
        # Cast the float32 sinusoid back to the dtype of `t` before the MLP.
        t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations.

    When ``dropout_prob > 0`` the table gets one extra row (index ``num_classes``),
    presumably the "null" label for classifier-free guidance. This module does NOT drop
    labels itself; ``forward`` is a plain lookup, so any label dropout must be applied by
    the caller before passing ``labels`` in.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                 Attention Block                               #
#################################################################################

class Attention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences (flash-attn varlen) with 2-D RoPE."""
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): `scale` and `attn_drop` are not used by `forward`; flash-attn is called
        # without `softmax_scale` / `dropout_p`, so its defaults apply.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, cu_seqlens, freqs_cos, freqs_sin) -> torch.Tensor:
        """Attend within each packed sequence.

        x: (N, C) tokens of all sequences concatenated (no batch dimension).
        cu_seqlens: cumulative sequence boundaries, starting at 0; sequence i owns rows
            cu_seqlens[i]:cu_seqlens[i+1].
        freqs_cos / freqs_sin: RoPE tables broadcast against q/k of shape
            (N, num_heads, head_dim) -- presumably (N, 1, head_dim), as built by NiT.get_rope.
        """
        N, C = x.shape
        # (N, 3*C) -> (3, N, heads, head_dim)
        qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        ori_dtype = qkv.dtype
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Apply rotary embeddings, then restore the input dtype (the RoPE tables may be in a
        # different dtype, which would otherwise promote q/k).
        q = q * freqs_cos + rotate_half(q) * freqs_sin
        k = k * freqs_cos + rotate_half(k) * freqs_sin
        q, k = q.to(ori_dtype), k.to(ori_dtype)

        # `.item()` copies to the host, so every attention call synchronizes with the device.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
        ).reshape(N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


#################################################################################
#                                 Core NiT Model                                #
#################################################################################

class NiTBlock(nn.Module):
    """
    A NiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    ``block_kwargs`` must contain ``"qk_norm"`` (a missing key raises KeyError); it may also
    set ``use_adaln_lora`` with ``adaln_lora_dim`` to factor the adaLN projection through a
    low-rank bottleneck.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
        )
        use_adaln_lora = block_kwargs.get('use_adaln_lora', False)
        if use_adaln_lora:
            adaln_lora_dim = block_kwargs['adaln_lora_dim']
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, adaln_lora_dim, bias=True),
                nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True)
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )

    def forward(self, x, c, cu_seqlens, freqs_cos, freqs_sin):
        # `c` is the per-token conditioning (same leading size as `x`); it yields shift/scale/gate
        # for both the attention and the MLP branch.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of NiT: adaLN-modulated LayerNorm, then a projection to
    ``patch_size * patch_size * out_channels`` values per token.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class NiT(nn.Module):
    """
    Diffusion model with a Transformer backbone operating on packed, variable-resolution
    token sequences (one patch per token, images concatenated along the token axis).
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        encoder_depth=4,
        projector_dim=2048,
        z_dim=768,
        use_checkpoint: bool = False,
        custom_freqs: str = 'normal',
        theta: int = 10000,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
        **block_kwargs # forwarded to every NiTBlock: must include qk_norm; optional use_adaln_lora / adaln_lora_dim
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.encoder_depth = encoder_depth
        self.use_checkpoint = use_checkpoint

        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False
        )
        self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        self.rope = VisionRotaryEmbedding(
            head_dim=hidden_size//num_heads, custom_freqs=custom_freqs, theta=theta,
            max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple, ori_max_pe_len=ori_max_pe_len
        )
        # Projects the output of block number `encoder_depth` to `z_dim` (returned as `zs`).
        self.projector = build_mlp(hidden_size, projector_dim, z_dim)

        self.blocks = nn.ModuleList([
            NiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in NiT blocks (adaLN-Zero: each block starts as identity):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, patch_size=None):
        """
        Reassemble a square token grid into images.

        x: (N, T, patch_size**2 * C) with T a perfect square (asserted).
        returns imgs: (N, C, H, W) with H = W = sqrt(T) * patch_size.

        Not used by `forward`, which returns per-token patches instead.
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        # Width reuses `h * p`; correct only because h == w is asserted above.
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def get_rope(self, hw_list):
        """Build RoPE tables for all packed images.

        hw_list: one (h, w) token-grid size per image. Returns (freqs_cos, freqs_sin), each with
        a singleton dim inserted at position 1 so they broadcast over attention heads.
        """
        grids = []
        for h, w in hw_list:
            grid_h = torch.arange(h)
            grid_w = torch.arange(w)
            # With indexing='xy' each output has shape (w, h): row 0 of the stacked grid holds
            # h-axis indices (varying fastest after flattening), row 1 holds w-axis indices.
            # NOTE(review): for h != w, confirm this flattening order matches the token order
            # produced by the packing pipeline and the (coord0, coord1) order expected by
            # `get_cached_2d_rope_from_grid`.
            grid = torch.meshgrid(grid_h, grid_w, indexing='xy')
            grid = torch.stack(grid, dim=0).reshape(2, -1)
            grids.append(grid)
        grids = torch.cat(grids, dim=-1)
        freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grids)
        return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)

    def forward(self, x, t, y, hw_list, return_zs=False, return_logvar=False):
        """
        Forward pass of NiT over B packed images with N tokens in total.
        x: (N, C, p, p) tensor of per-token patches (images or latent representations of images)
        t: (B,) tensor of diffusion timesteps, one per image
        y: (B,) tensor of class labels, one per image
        hw_list: (B, 2) token-grid sizes; image i contributes hw_list[i, 0] * hw_list[i, 1] tokens
        return_zs: also return [projector(output of block `encoder_depth`)] (empty list if
            `encoder_depth` is not reached)
        return_logvar: accepted but not used.
        Returns (N, out_channels, p, p) predictions, or (predictions, zs) when return_zs.
        """
        x = self.x_embedder(x) # (N, C, p, p) -> (N, 1, D): each patch becomes a single token
        x = x.squeeze(1) # (N, D)
        B = hw_list.shape[0]
        freqs_cos, freqs_sin = self.get_rope(hw_list) # (N, 1, D_h)
        seqlens = hw_list[:, 0] * hw_list[:, 1]
        # Cumulative boundaries [0, len0, len0+len1, ...] in int32 for flash_attn_varlen_func.
        cu_seqlens = torch.cat([
            torch.tensor([0], device=hw_list.device, dtype=torch.int),
            torch.cumsum(seqlens, dim=0, dtype=torch.int)
        ])
        # timestep and class embedding
        t_embed = self.t_embedder(t) # (B, D)
        y = self.y_embedder(y) # (B, D)
        c = t_embed + y # (B, D)
        # (B, D) -> (N, D): repeat each image's conditioning once per token of that image
        # (equivalent to torch.repeat_interleave(c, seqlens, dim=0)).
        c = torch.cat([c[i].unsqueeze(0).repeat(seqlens[i], 1) for i in range(B)], dim=0)

        zs=[]
        for i, block in enumerate(self.blocks):
            if not self.use_checkpoint:
                x = block(x, c, cu_seqlens, freqs_cos, freqs_sin) # (N, D)
            else:
                # NOTE(review): recent PyTorch versions warn when `use_reentrant` is not passed.
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, c, cu_seqlens, freqs_cos, freqs_sin
                )
            if (i + 1) == self.encoder_depth and return_zs:
                zs = [self.projector(x)]
        x = self.final_layer(x, c) # (N, out_channels * patch_size ** 2)
        # (N, out_channels * patch_size ** 2) -> (N, out_channels, p, p)
        x = rearrange(x, 'n (c p1 p2) -> n c p1 p2', p1=self.patch_size, p2=self.patch_size)
        if return_zs:
            return x, zs
        else:
            return x

    def ckpt_wrapper(self, module):
        # Plain function wrapper around `module` so it can be passed to torch.utils.checkpoint.
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs
        return ckpt_forward

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

================================================
FILE: nit/models/nvidia_radio/hubconf.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# Package list consumed by torch.hub.
dependencies = ["torch", "timm", "einops"]

import os
from typing import Dict, Any, Optional, Union, List
import warnings

import torch
from torch.hub import load_state_dict_from_url

from timm.models import clean_state_dict

from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool=False,
    **kwargs,
) -> RADIOModel:
    """Build a RADIO model from a local checkpoint file or a named release.

    An empty ``version`` means ``DEFAULT_VERSION``. If ``version`` is an existing file it is
    loaded with ``torch.load``; otherwise it is looked up in ``RESOURCE_MAP`` and the
    checkpoint is downloaded from that resource's URL.

    Args:
        version: checkpoint path or ``RESOURCE_MAP`` key.
        progress: forwarded to ``load_state_dict_from_url``.
        adaptor_names: teacher name(s) whose adaptor heads should be built; ``None`` builds none.
        vitdet_window_size: when not None, ViTDet windowed attention is applied to the backbone.
        return_checkpoint: when True, also return the raw checkpoint dict.
        support_packing: written onto the checkpoint ``args`` before the backbone is created.
        **kwargs: accepted but not used.

    Returns:
        The ``RADIOModel``, or ``(RADIOModel, checkpoint)`` when ``return_checkpoint`` is True.

    Raises:
        ValueError: if a requested adaptor name matches no teacher in the checkpoint.

    NOTE(review): checkpoints are loaded with ``weights_only=False``, which unpickles
    arbitrary objects -- only load trusted files.
    """
    if not version:
        version = DEFAULT_VERSION

    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )

    # Prefer EMA weights when present; spectral reparam is then treated as already disabled.
    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]

    args = chk["args"]
    args.support_packing = support_packing
    mod = create_model_from_args(args)

    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")

    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)

    # All later prefix lookups (conditioner, adaptors, normalizers) use the cleaned dict.
    state_dict = clean_state_dict(state_dict)

    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')

    if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API: instead of a plain
        # `weight` tensor, certain Linear layers compute their weight dynamically as
        # `w' = f(w)`. That helps stabilize training but is unnecessary downstream.
        # Disabling it here evaluates `w' = f(w)` once, during this function call, and replaces
        # the parametrization with the realized weights, which makes the model faster and
        # lighter on memory.
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False

    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))

    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype

    # Indices of teachers whose summary (CLS) token is used; first occurrence of each name wins.
    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)

    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]

    teachers = chk["args"].teachers
    adaptors = dict()
    for adaptor_name in adaptor_names:
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')

        ttype = tconf["type"]

        # Head weights may be keyed either by teacher index or by teacher name.
        pf_idx_head = f'_heads.{tidx}'
        pf_name_head = f'_heads.{adaptor_name}'
        pf_idx_feat = f'_feature_projections.{tidx}'
        pf_name_feat = f'_feature_projections.{adaptor_name}'

        # Re-key into the 'summary.*' / 'feature.*' layout consumed by the adaptor factories.
        adaptor_state = dict()
        for k, v in state_dict.items():
            if k.startswith(pf_idx_head):
                adaptor_state['summary' + k[len(pf_idx_head):]] = v
            elif k.startswith(pf_name_head):
                adaptor_state['summary' + k[len(pf_name_head):]] = v
            elif k.startswith(pf_idx_feat):
                adaptor_state['feature' + k[len(pf_idx_feat):]] = v
            elif k.startswith(pf_name_feat):
                adaptor_state['feature' + k[len(pf_name_feat):]] = v

        adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
        adaptor.head_idx = tidx if cls_token_per_teacher else 0
        adaptors[adaptor_name] = adaptor

    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )

    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk
    return radio


def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``, with the prefix stripped."""
    mod_state_dict = {
        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    return mod_state_dict

================================================
FILE: nit/models/nvidia_radio/radio/__init__.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# Register the adaptors
from .adaptor_registry import adaptor_registry
from . import open_clip_adaptor
from .adaptor_base import AdaptorInput, RadioOutput, AdaptorBase

# Enable support for other model types via the timm register_model mechanism
from . import extra_timm_models
from . import extra_models
from . import vision_transformer_xpos

================================================
FILE: nit/models/nvidia_radio/radio/adaptor_base.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace from typing import NamedTuple, Optional import torch from torch import nn import torch.nn.functional as F class AdaptorInput(NamedTuple): images: torch.Tensor summary: torch.Tensor features: torch.Tensor feature_fmt: str patch_size: int class RadioOutput(NamedTuple): summary: torch.Tensor features: torch.Tensor def to(self, *args, **kwargs): return RadioOutput( self.summary.to(*args, **kwargs) if self.summary is not None else None, self.features.to(*args, **kwargs) if self.features is not None else None, ) class AdaptorBase(nn.Module): def forward(self, input: AdaptorInput) -> RadioOutput: raise NotImplementedError("Subclasses must implement this!") ================================================ FILE: nit/models/nvidia_radio/radio/adaptor_generic.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from argparse import Namespace import torch from torch import nn import torch.nn.functional as F from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config class GenericAdaptor(AdaptorBase): def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None): super().__init__() extra_args = dict() ups = None ups_rank = None if adaptor_config is not None: ups = adaptor_config.get('fd_upsample_factor', None) ups_rank = adaptor_config.get('fd_upsample_rank', None) elif mlp_config is not None: ups = mlp_config["feature"].get('upsample_factor', None) ups_rank = mlp_config["feature"].get('upsample_rank', None) if ups is not None: extra_args['upsample_factor'] = ups extra_args['upsample_rank'] = ups_rank if state is not None: spectral_heads = getattr(main_config, 'spectral_heads', False) self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads) self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args) else: assert mlp_config is not None, "Config must not be None if state is None" self.head_mlp = create_mlp_from_config( main_config.mlp_version, mlp_config["summary"]["input_dim"], mlp_config["summary"]["hidden_dim"], mlp_config["summary"]["output_dim"], mlp_config["summary"]["num_inner"], ) self.feat_mlp = create_mlp_from_config( main_config.mlp_version, mlp_config["feature"]["input_dim"], mlp_config["feature"]["hidden_dim"], mlp_config["feature"]["output_dim"], mlp_config["feature"]["num_inner"], **extra_args ) def forward(self, input: AdaptorInput) -> RadioOutput: # Convert input'd type to the type of the first parameter of the adaptor. 
first_param = next(self.parameters()) summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype) feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype) if input.feature_fmt == 'NCHW': feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2]) .permute(0, 3, 1, 2) ) return RadioOutput(summary, feat) ================================================ FILE: nit/models/nvidia_radio/radio/adaptor_mlp.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
import math from typing import Dict, Optional import torch from torch import nn from einops import rearrange from timm.models.vision_transformer import Block from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam class MLP(nn.Module): def __init__(self, input_size: int, hidden_size: int, output_size: int, num_inner: int = 0, device: torch.device = None, **kwargs): super(MLP, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size, device=device) self.norm = nn.LayerNorm(hidden_size, device=device) self.relu = nn.ReLU() inner = [] for _ in range(num_inner): inner.extend([ nn.Linear(hidden_size, hidden_size, device=device), nn.LayerNorm(hidden_size, device=device), nn.ReLU(), ]) if inner: self.inner = nn.Sequential(*inner) else: self.inner = nn.Identity() self.fc2 = nn.Linear(hidden_size, output_size, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.norm(x) x = self.relu(x) x = self.inner(x) x = self.fc2(x) return x class MLP2(nn.Module): def __init__(self, input_size: int, hidden_size: int, output_size: int, num_inner: int = 0, pre_norm: bool = False, device: torch.device = None, upsample_factor: int = 1, upsample_rank: int = None, from_config: bool = False, **kwargs): super().__init__() self.pre_norm = nn.Sequential( nn.LayerNorm(input_size), nn.GELU(), ) if pre_norm else nn.Identity() self.upsample_factor = upsample_factor sq_ups = upsample_factor ** 2 self._real_output_dim = output_size // sq_ups # hidden_size *= upsample_factor # output_size *= (upsample_factor ** 2) self.fc1 = nn.Linear(input_size, hidden_size, device=device) blocks = [] for _ in range(num_inner): blocks.append(nn.Sequential( nn.LayerNorm(hidden_size, device=device), nn.GELU(), nn.Linear(hidden_size, hidden_size, device=device), )) self.blocks = nn.ModuleList(blocks) self.final = nn.Sequential( nn.LayerNorm(hidden_size, device=device), nn.GELU(), nn.Linear(hidden_size, output_size, device=device), ) def 
forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor: x = self.pre_norm(x) x = self.fc1(x) for block in self.blocks: x = x + block(x) x = self.final(x) if self.upsample_factor > 1: if images is None: raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!') if patch_size is None: raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!') h, w = tuple(d // patch_size for d in images.shape[-2:]) x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c', h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor, c=self._real_output_dim) return x MLP_FACTORY = { 'v1': MLP, 'v2': MLP2, } def strip_prefix(state: Dict[str, torch.Tensor], prefix: str): state = { k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix) } return state def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False): state = strip_prefix(state, prefix) weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original' if version == 'v1': hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape output_dim = state[f'fc2.{weight_suffix}'].shape[0] for num_inner in range(1000): k = f'inner.{num_inner}.0.weight' if k not in state: break elif version == 'v2': hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape output_dim = state[f'final.2.{weight_suffix}'].shape[0] for num_inner in range(1000): k = f'blocks.{num_inner}.0.weight' if k not in state: break else: raise ValueError(f'Unsupported MLP version: {version}') return input_dim, hidden_dim, output_dim, num_inner def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs): ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs) return ret def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], 
prefix: str = '', spectral_weights: bool = False, **kwargs): state = strip_prefix(state, prefix) input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights) ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs) if spectral_weights: enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state) ret.load_state_dict(state) if spectral_weights: disable_spectral_reparam(ret) return ret ================================================ FILE: nit/models/nvidia_radio/radio/adaptor_registry.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from argparse import Namespace from typing import Dict, Any import torch from .adaptor_generic import GenericAdaptor, AdaptorBase dict_t = Dict[str, Any] state_t = Dict[str, torch.Tensor] class AdaptorRegistry: def __init__(self): self._registry = {} def register_adaptor(self, name): def decorator(factory_function): if name in self._registry: raise ValueError(f"Model '{name}' already registered") self._registry[name] = factory_function return factory_function return decorator def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase: if name not in self._registry: return GenericAdaptor(main_config, adaptor_config, state) return self._registry[name](main_config, adaptor_config, state) # Creating an instance of the registry adaptor_registry = AdaptorRegistry() ================================================ FILE: nit/models/nvidia_radio/radio/block.py ================================================ # Ultralytics YOLO 🚀, AGPL-3.0 license """ Block modules """ import torch import torch.nn as nn from timm.models.layers import DropPath from .conv import Conv # from .transformer import TransformerBlock __all__ = ('C2f', 'Bottleneck',) class C2f(nn.Module): """Faster Implementation of CSP Bottleneck with 2 convolutions.""" def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion super().__init__() if drop_path is None: drop_path = [0.0] * n self.c = int(c2 * e) # hidden channels self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n)) def forward(self, x): """Forward pass through C2f layer.""" y = list(self.cv1(x).chunk(2, 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) def forward_split(self, x): """Forward pass using split() instead of 
chunk().""" y = list(self.cv1(x).split((self.c, self.c), 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) class Bottleneck(nn.Module): """Standard bottleneck.""" def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, k[0], 1) self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.add = shortcut and c1 == c2 self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): """'forward()' applies the YOLOv5 FPN to input data.""" return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x)) ================================================ FILE: nit/models/nvidia_radio/radio/cls_token.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from typing import Optional import torch from torch import nn class ClsToken(nn.Module): def __init__(self, ndim: int, num_tokens: int = 1, enabled: bool = True, register_multiple: Optional[int] = None, num_registers: Optional[int] = None, ): super().__init__() self.ndim = ndim self.enabled = enabled self.num_registers = 0 self.num_tokens = num_tokens if enabled: if num_registers: self.num_registers = num_registers elif register_multiple: self.num_registers = register_multiple - (num_tokens % register_multiple) scale = ndim ** -0.5 self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale) else: self.token = None self.num_patches = self.num_tokens + self.num_registers def disable(self): self.token = None self.enabled = False def forward(self, x: torch.Tensor): if self.token is None: return x token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1) x = torch.cat([ token, x, ], dim=1) return x def no_weight_decay(self): return [ 'token', ] ================================================ FILE: nit/models/nvidia_radio/radio/common.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from dataclasses import dataclass from typing import Optional from .radio_model import Resolution @dataclass class RadioResource: url: str patch_size: int max_resolution: int preferred_resolution: Resolution vitdet_num_windowed: Optional[int] = None vitdet_num_global: Optional[int] = None RESOURCE_MAP = { # RADIOv2.5 "radio_v2.5-b": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=(768, 768), vitdet_num_global=4, ), "radio_v2.5-l": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=(768, 768), vitdet_num_global=4, ), "radio_v2.5-h": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=(768, 768), vitdet_num_global=4, ), "radio_v2.5-h-norm": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=(768, 768), vitdet_num_global=4, ), "radio_v2.5-g": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true", patch_size=14, max_resolution=1792, preferred_resolution=(896, 896), vitdet_num_global=8, ), # RADIO "radio_v2.1": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=Resolution(432, 432), vitdet_num_windowed=5, ), "radio_v2": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=Resolution(432, 432), vitdet_num_windowed=5, ), "radio_v1": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true", patch_size=14, max_resolution=1050, 
preferred_resolution=Resolution(378, 378), ), # E-RADIO "e-radio_v2": RadioResource( "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=Resolution(512, 512), ), # C-RADIO "c-radio_v2.5-g": RadioResource( "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar", patch_size=16, max_resolution=2048, preferred_resolution=(768, 768), vitdet_num_global=8, ), "c-radio_v3-l": RadioResource( # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L # and accept the license terms. "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true", patch_size=16, max_resolution=2048, preferred_resolution=Resolution(512, 512), ), } DEFAULT_VERSION = "radio_v2.5-h" ================================================ FILE: nit/models/nvidia_radio/radio/conv.py ================================================ # Ultralytics YOLO 🚀, AGPL-3.0 license """ Convolution modules """ import math import numpy as np import torch import torch.nn as nn __all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv') def autopad(k, p=None, d=1): # kernel, padding, dilation """Pad to 'same' shape outputs.""" if d > 1: k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad return p # Pavlo's implementation with switch to deploy class Conv(nn.Module): default_act = nn.SiLU() # default activation def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True): super().__init__() self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, 
bias=False) if 1: self.bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(self.bn.weight, bn_weight_init) torch.nn.init.constant_(self.bn.bias, 0) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self,x): x = self.conv(x) x = self.bn(x) x = self.act(x) return x @torch.no_grad() def switch_to_deploy(self): if not isinstance(self.bn, nn.Identity): # return 1 c, bn = self.conv, self.bn w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / \ (bn.running_var + bn.eps)**0.5 # m = torch.nn.Conv2d(w.size(1) * c.groups, # w.size(0), # w.shape[2:], # stride=c.stride, # padding=c.padding, # dilation=c.dilation, # groups=c.groups) self.conv.weight.data.copy_(w) self.conv.bias = nn.Parameter(b) # self.conv.bias.data.copy_(b) # self.conv = m.to(c.weight.device) self.bn = nn.Identity() ================================================ FILE: nit/models/nvidia_radio/radio/dinov2_arch.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # References: # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py # Nvidia # NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking, # but also because Huggingface does a string replace of `gamma` to something else when loading the model state, # and this breaks loading of this model. 
from enum import Enum from functools import partial import logging import math import os import sys from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import warnings import torch from torch import nn from torch.nn import functional as F from torch.nn.init import trunc_normal_ _torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention') XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None try: if XFORMERS_ENABLED: from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind XFORMERS_AVAILABLE = True else: raise ImportError except ImportError: XFORMERS_AVAILABLE = False def make_2tuple(x): if isinstance(x, tuple): assert len(x) == 2 return x assert isinstance(x, int) return (x, x) class PatchEmbed(nn.Module): """ 2D image to patch embedding: (B,C,H,W) -> (B,N,D) Args: img_size: Image size. patch_size: Patch token size. in_chans: Number of input image channels. embed_dim: Number of linear projection output channels. norm_layer: Normalization layer. 
""" def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, norm_layer: Optional[Callable] = None, flatten_embedding: bool = True, ) -> None: super().__init__() image_HW = make_2tuple(img_size) patch_HW = make_2tuple(patch_size) patch_grid_size = ( image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1], ) self.img_size = image_HW self.patch_size = patch_HW self.patches_resolution = patch_grid_size self.num_patches = patch_grid_size[0] * patch_grid_size[1] self.in_chans = in_chans self.embed_dim = embed_dim self.flatten_embedding = flatten_embedding self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: _, _, H, W = x.shape patch_H, patch_W = self.patch_size assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" x = self.proj(x) # B C H W H, W = x.size(2), x.size(3) x = x.flatten(2).transpose(1, 2) # B HW C x = self.norm(x) if not self.flatten_embedding: x = x.reshape(-1, H, W, self.embed_dim) # B H W C return x def flops(self) -> float: Ho, Wo = self.patches_resolution flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) if self.norm is not None: flops += Ho * Wo * self.embed_dim return flops class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, ) -> None: super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) def 
forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] if _torch_has_sdpa: x = F.scaled_dot_product_attention( q, k, v, is_causal=False, dropout_p=self.attn_drop.p if self.training else 0., scale=self.scale, ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class MemEffAttention(Attention): def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: if not XFORMERS_AVAILABLE: if attn_bias is not None: raise AssertionError("xFormers is required for using nested tensors") return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x class Mlp(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = nn.GELU, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop = nn.Dropout(drop) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class SwiGLUFFN(nn.Module): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: super().__init__() out_features = 
out_features or in_features hidden_features = hidden_features or in_features self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) self.w3 = nn.Linear(hidden_features, out_features, bias=bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x12 = self.w12(x) x1, x2 = x12.chunk(2, dim=-1) hidden = F.silu(x1) * x2 return self.w3(hidden) if not XFORMERS_AVAILABLE: SwiGLU = SwiGLUFFN class SwiGLUFFNFused(SwiGLU): def __init__( self, in_features: int, hidden_features: Optional[int] = None, out_features: Optional[int] = None, act_layer: Callable[..., nn.Module] = None, drop: float = 0.0, bias: bool = True, ) -> None: out_features = out_features or in_features hidden_features = hidden_features or in_features hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 super().__init__( in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias, ) def drop_path(x, drop_prob: float = 0.0, training: bool = False): if drop_prob == 0.0 or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = x.new_empty(shape).bernoulli_(keep_prob) if keep_prob > 0.0: random_tensor.div_(keep_prob) output = x * random_tensor return output class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) class LayerScale(nn.Module): def __init__( self, dim: int, init_values: Union[float, torch.Tensor] = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.grandma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.grandma) if self.inplace else x * self.grandma def _load_from_state_dict(self, state_dict, prefix, 
local_metadata, strict, missing_keys, unexpected_keys, error_msgs): # Huggingface is absurd and it will rename strings that contain `gamma`, which means that the normal DINO implementation # of LayerScale won't work with HFHub. So we rename the variable to 'grandma', and support loading checkpoints in either # format key_a = f'{prefix}gamma' key_b = f'{prefix}grandma' if key_a in state_dict: gamma = state_dict[key_a] elif key_b in state_dict: gamma = state_dict[key_b] else: if strict: raise KeyError(f"Couldn't find the key {key_a} nor {key_b} in the state dict!") else: missing_keys.append(key_a) missing_keys.append(key_b) unexpected_keys.extend(state_dict.keys()) gamma = None if gamma is not None: self.grandma.data.copy_(gamma) # return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, proj_bias: bool = True, ffn_bias: bool = True, drop: float = 0.0, attn_drop: float = 0.0, init_values=None, drop_path: float = 0.0, act_layer: Callable[..., nn.Module] = nn.GELU, norm_layer: Callable[..., nn.Module] = nn.LayerNorm, attn_class: Callable[..., nn.Module] = Attention, ffn_layer: Callable[..., nn.Module] = Mlp, ) -> None: super().__init__() # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") self.norm1 = norm_layer(dim) self.attn = attn_class( dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = ffn_layer( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else 
nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.sample_drop_ratio = drop_path def forward(self, x: torch.Tensor) -> torch.Tensor: def attn_residual_func(x: torch.Tensor) -> torch.Tensor: return self.ls1(self.attn(self.norm1(x))) def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: return self.ls2(self.mlp(self.norm2(x))) if self.training and self.sample_drop_ratio > 0.1: # the overhead is compensated only for a drop path rate larger than 0.1 x = drop_add_residual_stochastic_depth( x, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) x = drop_add_residual_stochastic_depth( x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, ) elif self.training and self.sample_drop_ratio > 0.0: x = x + self.drop_path1(attn_residual_func(x)) x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 else: x = x + attn_residual_func(x) x = x + ffn_residual_func(x) return x class NestedTensorBlock(Block): def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: """ x_list contains a list of tensors to nest together and run """ assert isinstance(self.attn, MemEffAttention) if self.training and self.sample_drop_ratio > 0.0: def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.attn(self.norm1(x), attn_bias=attn_bias) def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.mlp(self.norm2(x)) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None, ) x_list = drop_add_residual_stochastic_depth_list( x_list, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio, scaling_vector=self.ls2.grandma if isinstance(self.ls1, LayerScale) else None, ) return x_list else: def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: 
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: return self.ls2(self.mlp(self.norm2(x))) attn_bias, x = get_attn_bias_and_cat(x_list) x = x + attn_residual_func(x, attn_bias=attn_bias) x = x + ffn_residual_func(x) return attn_bias.split(x) def forward(self, x_or_x_list): if isinstance(x_or_x_list, torch.Tensor): return super().forward(x_or_x_list) elif isinstance(x_or_x_list, list): if not XFORMERS_AVAILABLE: raise AssertionError("xFormers is required for using nested tensors") return self.forward_nested(x_or_x_list) else: raise AssertionError def drop_add_residual_stochastic_depth( x: torch.Tensor, residual_func: Callable[[torch.Tensor], torch.Tensor], sample_drop_ratio: float = 0.0, ) -> torch.Tensor: # 1) extract subset using permutation b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] x_subset = x[brange] # 2) apply residual_func to get residual residual = residual_func(x_subset) x_flat = x.flatten(1) residual = residual.flatten(1) residual_scale_factor = b / sample_subset_size # 3) add the residual x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) return x_plus_residual.view_as(x) def get_branges_scales(x, sample_drop_ratio=0.0): b, n, d = x.shape sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) brange = (torch.randperm(b, device=x.device))[:sample_subset_size] residual_scale_factor = b / sample_subset_size return brange, residual_scale_factor def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): if scaling_vector is None: x_flat = x.flatten(1) residual = residual.flatten(1) x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) else: x_plus_residual = scaled_index_add( x, brange, residual.to(dtype=x.dtype), 
scaling=scaling_vector, alpha=residual_scale_factor ) return x_plus_residual attn_bias_cache: Dict[Tuple, Any] = {} def get_attn_bias_and_cat(x_list, branges=None): """ this will perform the index select, cat the tensors, and provide the attn_bias from cache """ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) if all_shapes not in attn_bias_cache.keys(): seqlens = [] for b, x in zip(batch_sizes, x_list): for _ in range(b): seqlens.append(x.shape[1]) attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) attn_bias._batch_sizes = batch_sizes attn_bias_cache[all_shapes] = attn_bias if branges is not None: cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) else: tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) cat_tensors = torch.cat(tensors_bs1, dim=1) return attn_bias_cache[all_shapes], cat_tensors def drop_add_residual_stochastic_depth_list( x_list: List[torch.Tensor], residual_func: Callable[[torch.Tensor, Any], torch.Tensor], sample_drop_ratio: float = 0.0, scaling_vector=None, ) -> torch.Tensor: # 1) generate random set of indices for dropping samples in the batch branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] branges = [s[0] for s in branges_scales] residual_scale_factors = [s[1] for s in branges_scales] # 2) get attention bias and index+concat the tensors attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) # 3) apply residual_func to get residual, and split the result residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore outputs = [] for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) return outputs def named_apply(fn: 
Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: if not depth_first and include_root: fn(module=module, name=name) for child_name, child_module in module.named_children(): child_name = ".".join((name, child_name)) if name else child_name named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) if depth_first and include_root: fn(module=module, name=name) return module class BlockChunk(nn.ModuleList): def forward(self, x): for b in self: x = b(x) return x class DinoVisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, ffn_bias=True, proj_bias=True, drop_path_rate=0.0, drop_path_uniform=False, init_values=None, # for layerscale: None or 0 => no layerscale embed_layer=PatchEmbed, act_layer=nn.GELU, block_fn=Block, ffn_layer="mlp", block_chunks=1, num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True proj_bias (bool): enable bias for proj in attn if True ffn_bias (bool): enable bias for ffn if True drop_path_rate (float): stochastic depth rate drop_path_uniform (bool): apply uniform drop rate across blocks weight_init (str): weight init scheme init_values (float): layer-scale init values embed_layer (nn.Module): patch embedding layer act_layer (nn.Module): MLP activation layer block_fn (nn.Module): transformer block class ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" block_chunks: (int) split block sequence into block_chunks units for FSDP wrap num_register_tokens: (int) number of extra cls tokens 
(so-called "registers") interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings """ super().__init__() norm_layer = partial(nn.LayerNorm, eps=1e-6) self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 1 self.n_blocks = depth self.num_heads = num_heads self.patch_size = patch_size self.num_register_tokens = num_register_tokens self.interpolate_antialias = interpolate_antialias self.interpolate_offset = interpolate_offset self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) assert num_register_tokens >= 0 self.register_tokens = ( nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None ) if drop_path_uniform is True: dpr = [drop_path_rate] * depth else: dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule if ffn_layer == "mlp": ffn_layer = Mlp elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": ffn_layer = SwiGLUFFNFused elif ffn_layer == "identity": def f(*args, **kwargs): return nn.Identity() ffn_layer = f else: raise NotImplementedError blocks_list = [ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, ffn_layer=ffn_layer, init_values=init_values, ) for i in range(depth) ] if block_chunks > 0: self.chunked_blocks = True chunked_blocks = [] chunksize = depth // block_chunks for i in range(0, depth, chunksize): # this is to keep the block index consistent if we chunk the block list 
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) else: self.chunked_blocks = False self.blocks = nn.ModuleList(blocks_list) self.norm = norm_layer(embed_dim) self.head = nn.Identity() self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) def interpolate_pos_encoding(self, x, w, h): previous_dtype = x.dtype npatch = x.shape[1] - 1 N = self.pos_embed.shape[1] - 1 if npatch == N and w == h: return self.pos_embed pos_embed = self.pos_embed.float() class_pos_embed = pos_embed[:, 0] patch_pos_embed = pos_embed[:, 1:] dim = x.shape[-1] w0 = w // self.patch_size h0 = h // self.patch_size M = int(math.sqrt(N)) # Recover the number of patches in each dimension assert N == M * M kwargs = {} if self.interpolate_offset: # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors sx = float(w0 + self.interpolate_offset) / M sy = float(h0 + self.interpolate_offset) / M kwargs["scale_factor"] = (sx, sy) else: # Simply specify an output size instead of a scale factor kwargs["size"] = (w0, h0) patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), mode="bicubic", antialias=self.interpolate_antialias, **kwargs, ) assert (w0, h0) == patch_pos_embed.shape[-2:] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) def prepare_tokens_with_masks(self, x, masks=None): B, nc, w, h = x.shape x = self.patch_embed(x) if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) if 
self.register_tokens is not None: x = torch.cat( ( x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:], ), dim=1, ) return x def forward_features_list(self, x_list, masks_list): x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] for blk in self.blocks: x = blk(x) all_x = x output = [] for x, masks in zip(all_x, masks_list): x_norm = self.norm(x) output.append( { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } ) return output def forward_features(self, x, masks=None): if isinstance(x, list): return self.forward_features_list(x, masks) x = self.prepare_tokens_with_masks(x, masks) for blk in self.blocks: x = blk(x) x_norm = self.norm(x) return { "x_norm_clstoken": x_norm[:, 0], "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], "x_prenorm": x, "masks": masks, } def _get_intermediate_layers_not_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) # If n is an int, take the n last blocks. If it's a list, take them output, total_block_len = [], len(self.blocks) blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for i, blk in enumerate(self.blocks): x = blk(x) if i in blocks_to_take: output.append(x) assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def _get_intermediate_layers_chunked(self, x, n=1): x = self.prepare_tokens_with_masks(x) output, i, total_block_len = [], 0, len(self.blocks[-1]) # If n is an int, take the n last blocks. 
If it's a list, take them blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n for block_chunk in self.blocks: for blk in block_chunk[i:]: # Passing the nn.Identity() x = blk(x) if i in blocks_to_take: output.append(x) i += 1 assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" return output def get_intermediate_layers( self, x: torch.Tensor, n: Union[int, Sequence] = 1, # Layers or n last layers to take reshape: bool = False, return_class_token: bool = False, norm=True, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: if self.chunked_blocks: outputs = self._get_intermediate_layers_chunked(x, n) else: outputs = self._get_intermediate_layers_not_chunked(x, n) if norm: outputs = [self.norm(out) for out in outputs] class_tokens = [out[:, 0] for out in outputs] outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] if reshape: B, _, w, h = x.shape outputs = [ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() for out in outputs ] if return_class_token: return tuple(zip(outputs, class_tokens)) return tuple(outputs) def forward(self, *args, is_training=False, **kwargs): ret = self.forward_features(*args, **kwargs) if is_training: return ret else: return self.head(ret["x_norm_clstoken"]) def vit_small(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_base(patch_size=16, num_register_tokens=0, **kwargs): model = DinoVisionTransformer( patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_large(patch_size=16, num_register_tokens=0, **kwargs): model = 
DinoVisionTransformer( patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): """ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 """ model = DinoVisionTransformer( patch_size=patch_size, embed_dim=1536, depth=40, num_heads=24, mlp_ratio=4, block_fn=partial(Block, attn_class=MemEffAttention), num_register_tokens=num_register_tokens, **kwargs, ) return model class Weights(Enum): LVD142M = "LVD142M" def _make_dinov2_model( *, arch_name: str = "vit_large", img_size: int = 518, patch_size: int = 14, init_values: float = 1.0, ffn_layer: str = "mlp", block_chunks: int = 0, num_register_tokens: int = 0, interpolate_antialias: bool = False, interpolate_offset: float = 0.1, weights: Union[Weights, str] = Weights.LVD142M, **kwargs, ): if isinstance(weights, str): try: weights = Weights[weights] except KeyError: raise AssertionError(f"Unsupported weights: {weights}") vit_kwargs = dict( img_size=img_size, patch_size=patch_size, init_values=init_values, ffn_layer=ffn_layer, block_chunks=block_chunks, num_register_tokens=num_register_tokens, interpolate_antialias=interpolate_antialias, interpolate_offset=interpolate_offset, ) vit_kwargs.update(**kwargs) model = sys.modules[__name__].__dict__[arch_name](**vit_kwargs) return model def dinov2_vits14(**kwargs): """ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model(arch_name="vit_small", **kwargs) def dinov2_vitb14(**kwargs): """ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model(arch_name="vit_base", **kwargs) def dinov2_vitl14(**kwargs): """ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 
""" return _make_dinov2_model(arch_name="vit_large", **kwargs) def dinov2_vitg14(**kwargs): """ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( arch_name="vit_giant2", ffn_layer="swiglufused", **kwargs, ) def dinov2_vits14_reg(**kwargs): """ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( arch_name="vit_small", num_register_tokens=4, interpolate_antialias=True, interpolate_offset=0.0, **kwargs, ) def dinov2_vitb14_reg(**kwargs): """ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( arch_name="vit_base", num_register_tokens=4, interpolate_antialias=True, interpolate_offset=0.0, **kwargs, ) def dinov2_vitl14_reg(**kwargs): """ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( arch_name="vit_large", num_register_tokens=4, interpolate_antialias=True, interpolate_offset=0.0, **kwargs, ) def dinov2_vitg14_reg(**kwargs): """ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. """ return _make_dinov2_model( arch_name="vit_giant2", ffn_layer="swiglufused", num_register_tokens=4, interpolate_antialias=True, interpolate_offset=0.0, **kwargs, ) ================================================ FILE: nit/models/nvidia_radio/radio/dual_hybrid_vit.py ================================================ from logging import getLogger from typing import Tuple import torch from torch import nn from torch.nn import functional as F from timm.models import register_model from timm.models import vision_transformer as tvit from timm.models import convnext as tconv from einops import rearrange from . 
import extra_timm_models as et class Fuser(nn.Module): def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True): super().__init__() self.gated = gated mid_dim = max(src_dim, tgt_dim) * 2 self.fwd = nn.Sequential( nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1), nn.GELU(), nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1), ) def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor: if src.ndim == 3: shape = tgt.shape[-2:] else: shape = src.shape[-2:] nd = shape[0] * shape[1] if src.ndim == 3: src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape) if tgt.ndim == 3: tgt_pre = tgt[:, :-nd] tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape) else: tgt_pre = None pred = self.fwd(src) if self.gated: g, pred = torch.chunk(pred, 2, dim=1) g = F.sigmoid(g) pred = g * pred tgt = tgt + pred if tgt_pre is not None: tgt = rearrange(tgt, 'b c h w -> b (h w) c') tgt = torch.cat([tgt_pre, tgt], dim=1) return tgt class AttnDownsample(nn.Module): def __init__(self, dim: int, window_size: int, num_heads: int = 16): super().__init__() self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01) self.kv = nn.Linear(dim, dim * 2) self.proj = nn.Linear(dim, dim) self.window_size = window_size self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor: ntok = twod_shape[0] * twod_shape[1] x_pre = x[:, :-ntok] B = x.shape[0] ds_hw = tuple(s // self.window_size for s in twod_shape) x_spat = rearrange( x[:, -ntok:], 'b (h d1 w d2) c -> (b h w) (d1 d2) c', h=ds_hw[0], w=ds_hw[1], d1=self.window_size, d2=self.window_size, ) B, N, C = x_spat.shape k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q = (self.q * self.scale).expand(B, -1, -1, -1) attn = q @ k.transpose(-2, -1) attn = F.softmax(attn, dim=-1) x = attn @ v x = 
x.transpose(1, 2).reshape(B, C) x = self.proj(x) x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1]) x = torch.cat([x_pre, x], dim=1) return x class HybridModel(nn.Module): def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False, concatenate: bool = False, **kwargs): super().__init__() self.conv = conv self.vit = vit self.concatenate = concatenate conv.stages = nn.ModuleList(conv.stages) vit.blocks = nn.ModuleList(vit.blocks) self._half_vit_idx = len(vit.blocks) // 2 + 1 self._half_conv_idx = None x = torch.empty(1, 3, 256, 256) x = self.conv.stem(x) for i in range(len(conv.stages)): x = conv.stages[i](x) if self._half_conv_idx is None and x.shape[-2:] == (16, 16): self._half_conv_idx = i + 1 half_conv_dim = x.shape[1] final_conv_dim = x.shape[1] self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim) self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim) self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2) embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0) if not concatenate: self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False) self.final_block = tvit.Block(embed_dim, num_heads=16) self.embed_dim = embed_dim @property def patch_size(self): return 32 @property def no_fsdp_wrap_types(self): return {tvit.VisionTransformer, tconv.ConvNeXt} def forward(self, x: torch.Tensor) -> torch.Tensor: return self.forward_features(x) def forward_features(self, x: torch.Tensor) -> torch.Tensor: y_vit = self.vit.patch_generator(x) for i in range(self._half_vit_idx): y_vit = self.vit.blocks[i](y_vit) y_conv = self.conv.stem(x) for i in range(self._half_conv_idx): y_conv = self.conv.stages[i](y_conv) y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv) y_vit = self.vit_ds(y_vit, y_conv.shape[-2:]) for i in range(self._half_vit_idx, len(self.vit.blocks)): y_vit = self.vit.blocks[i](y_vit) for i in range(self._half_conv_idx, 
len(self.conv.stages)): y_conv = self.conv.stages[i](y_conv) if self.concatenate: y_conv = rearrange(y_conv, 'b c h w -> b (h w) c') # Average pool across the board, and replicate for each cls/register token conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1) y_conv = torch.cat([conv_summary, y_conv], dim=1) y = torch.cat([y_vit, y_conv], dim=2) else: y = self.final_fuse(y_conv, y_vit) y = self.final_block(y) summary = y[:, :self.vit.patch_generator.num_cls_tokens] features = y[:, self.vit.patch_generator.num_cls_patches:] return summary, features @register_model def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs): cfg = dict(num_classes=0, **kwargs) conv = tconv.convnextv2_base(pretrained=pretrained, **cfg) vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg) return HybridModel(vit, conv, pretrained, concatenate=concatenate) @register_model def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs): cfg = dict(num_classes=0, **kwargs) conv = tconv.convnextv2_large(pretrained=pretrained, **cfg) vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg) return HybridModel(vit, conv, pretrained, concatenate=concatenate) @register_model def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs): cfg = dict(num_classes=0, **kwargs) conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg) vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg) return HybridModel(vit, conv, pretrained, concatenate=concatenate) ================================================ FILE: nit/models/nvidia_radio/radio/enable_cpe_support.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. 
# # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from typing import List, Optional, Set, Tuple, Union from types import MethodType import torch from torch import nn from timm.models import VisionTransformer, checkpoint_seq from timm.models.vision_transformer import Attention, Block from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer from .extra_models import DinoWrapper from .vit_patch_generator import ViTPatchGenerator from .forward_intermediates import forward_intermediates from .dual_hybrid_vit import HybridModel from flash_attn import flash_attn_varlen_func def _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: N, C = x.shape qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3) q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() x = flash_attn_varlen_func( q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(N, -1) x = self.proj(x) x = self.proj_drop(x) return x def _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor: x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_seqlens))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x def _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor]) -> torch.Tensor: device = images[0].device x = [] seqlens = [] for image in images: # image: [1, c, H, W] -> x: [n_cls+h*w, D], h=H/p and w=W/p _image = self.patch_generator(image).squeeze(0) x.append(_image) seqlens.append(_image.shape[0]) x = torch.cat(x, dim=0) seqlens = 
torch.tensor(seqlens, device=device, dtype=torch.int) cu_seqlens = torch.cat([ torch.tensor([0], device=device, dtype=torch.int32), torch.cumsum(seqlens, dim=0, dtype=torch.int32) ]) if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting(): for block in self.blocks: x = checkpoint_seq(block, x, cu_seqlens) else: for block in self.blocks: x = block(x, cu_seqlens) x = self.norm(x) return x, cu_seqlens def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor: x = self.patch_generator(x) if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) return x def _take_indices( num_blocks: int, n: Optional[Union[int, List[int], Tuple[int]]], ) -> Tuple[Set[int], int]: if isinstance(n, int): assert n >= 0 take_indices = {x for x in range(num_blocks - n, num_blocks)} else: take_indices = {num_blocks + idx if idx < 0 else idx for idx in n} return take_indices, max(take_indices) def _forward_intermediates_cpe( self, x: torch.Tensor, norm: bool = False, **kwargs, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: return forward_intermediates( self, patch_extractor=self.patch_generator, num_summary_tokens=self.patch_generator.num_skip, num_cls_tokens=self.patch_generator.num_cls_tokens, norm=self.norm if norm else lambda y: y, x=x, **kwargs, ) def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor: y = _forward_cpe(self.inner, x) return y[:, 0], y[:, self.num_summary_tokens:] def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs): return _forward_intermediates_cpe(self.inner, *args, **kwargs) def _enable_cpe_for_timm_vit(model: VisionTransformer, max_img_size: Union[int, Tuple[int, int]] = 1024, num_cls_tokens: int = 1, pos_dropout: float = 0.1, register_multiple: int = Optional[None], num_registers: int = Optional[None], support_packing: bool = False, ): if not isinstance(model, 
VisionTransformer): raise ValueError("CPE only support for VisionTransformer models!") patch_size = model.patch_embed.patch_size[0] embed_dim = model.embed_dim input_dims = model.patch_embed.img_size normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity) cls_token = model.cls_token is not None max_img_size = int(round(max_img_size / patch_size) * patch_size) patch_generator = ViTPatchGenerator( patch_size=patch_size, embed_dim=embed_dim, input_dims=input_dims, normalize_patches=normalize_patches, cls_token=cls_token, max_input_dims=max_img_size, pos_dropout=pos_dropout, num_cls_tokens=num_cls_tokens, register_multiple=register_multiple, num_registers=num_registers, ) model.patch_generator = patch_generator model.patch_embed = None model.cls_token = None model.pos_embed = None model.pos_drop = None model.patch_size = patch_size model.num_cls_tokens = num_cls_tokens model.num_registers = patch_generator.num_registers model.forward_features = MethodType(_forward_cpe, model) model.forward_intermediates = MethodType(_forward_intermediates_cpe, model) if support_packing: model.forward_features = MethodType(_forward_cpe_pack, model) for block in model.blocks: block.forward = MethodType(_block_forward_pack, block) block.attn.forward = MethodType(_attn_forward_pack, block.attn) def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper, max_img_size: Union[int, Tuple[int, int]] = 1024, num_cls_tokens: int = 1, pos_dropout: float = 0.1, register_multiple: int = Optional[None], num_registers: int = Optional[None], ): patch_size = model.patch_size embed_dim = model.embed_dim input_dims = model.inner.patch_embed.patches_resolution normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity) cls_token = True max_img_size = int(round(max_img_size / patch_size) * patch_size) patch_generator = ViTPatchGenerator( patch_size=patch_size, embed_dim=embed_dim, input_dims=input_dims, normalize_patches=normalize_patches, cls_token=cls_token, 
max_input_dims=max_img_size, pos_dropout=pos_dropout, num_cls_tokens=num_cls_tokens, register_multiple=register_multiple, num_registers=num_registers, patch_bias=True, ) inner = model.inner inner.patch_generator = patch_generator inner.patch_embed = None inner.cls_token = None inner.pos_embed = None inner.register_tokens = None inner.patch_size = patch_size model.forward_features = MethodType(_forward_cpe_dinov2, model) model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model) def enable_cpe(model: nn.Module, *args, **kwargs, ): if isinstance(model, VisionTransformer): _enable_cpe_for_timm_vit(model, *args, **kwargs) elif isinstance(model, DinoWrapper): _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs) elif isinstance(model, HybridModel): _enable_cpe_for_timm_vit(model.vit, *args, **kwargs) else: raise ValueError(f'CPE not supported for this model type: {type(model)}') ================================================ FILE: nit/models/nvidia_radio/radio/enable_damp.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from logging import getLogger import math import os from typing import Dict, List, Optional, Union, Tuple from types import MethodType import torch from torch import nn from torch.nn import functional as F from torch.nn.utils import parametrize # For now, don't do anything class DAMP(nn.Identity): def __init__(self, std: float): super().__init__() self.std = std def enable_damp(model: nn.Module, std: float): if isinstance(model, (list, tuple)): for m in model: enable_damp(m, std) return for name, module in model.named_modules(): if isinstance(module, nn.Linear): parametrize.register_parametrization(module, 'weight', DAMP(std)) def configure_damp_from_args(model: nn.Module, args): damp = getattr(args, 'damp', None) if damp: enable_damp(model, damp) ================================================ FILE: nit/models/nvidia_radio/radio/enable_spectral_reparam.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from logging import getLogger import math import os from typing import Dict, List, Optional, Union, Tuple from types import MethodType import torch from torch import nn from torch.nn import functional as F from torch.nn.utils import parametrize from torch.nn.utils.parametrizations import _SpectralNorm from timm.models.vision_transformer import Attention, Mlp _EPS = 1e-5 class _SNReweight(_SpectralNorm): def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs): super().__init__(weight, *args, **kwargs) self.alpha = alpha self.version = version self.register_buffer('_sn_version', torch.tensor(version)) if init_norm_to_current: # This will set the numerator to match the denominator, which should preserve the original values init_scale = self._get_sigma(weight, n_power_iterations=20).item() else: init_scale = 1.0 if version == 1: init_value = init_scale elif version == 2: t = init_scale - alpha if t < _EPS: getLogger("spectral_reparam").warn(f'The initialized spectral norm {init_scale} is too small to be represented. 
Setting to {_EPS} instead.') t = _EPS init_value = math.log(math.exp(t) - 1) else: raise ValueError(f'Unsupported version: {version}') # Make 2D so that weight decay gets applied self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device)) # Re-implementing this because we need to make division by sigma safe def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = None) -> torch.Tensor: if not n_power_iterations: n_power_iterations = self.n_power_iterations if weight.ndim == 1: # Faster and more exact path, no need to approximate anything sigma = weight.norm() else: weight_mat = self._reshape_weight_to_matrix(weight) if self.training: self._power_method(weight_mat, n_power_iterations) # See above on why we need to clone u = self._u.clone(memory_format=torch.contiguous_format) v = self._v.clone(memory_format=torch.contiguous_format) # The proper way of computing this should be through F.bilinear, but # it seems to have some efficiency issues: # https://github.com/pytorch/pytorch/issues/58093 sigma = torch.dot(u, torch.mv(weight_mat, v)) return sigma + self.eps def forward(self, weight: torch.Tensor, *args, **kwargs): dtype = weight.dtype sigma = self._get_sigma(weight, *args, **kwargs) if self.version == 1: scale = self.scale elif self.version == 2: scale = F.softplus(self.scale) + self.alpha else: raise ValueError(f'Unsupported version: {self.version}') scale = scale.float() / sigma.float() y = weight * scale if dtype in (torch.float16, torch.bfloat16): y = y.to(dtype) return y def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): version_key = f'{prefix}_sn_version' if version_key not in state_dict: self.version = 1 state_dict[version_key] = torch.tensor(1) return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) class _ChunkedSNReweight(nn.Module): def __init__(self, weight: 
torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs): super().__init__() self.num_chunks = num_chunks parts = weight.split(weight.shape[0] // num_chunks, dim=0) self.parts = nn.ModuleList([ _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs) for p in parts ]) def forward(self, weight: torch.Tensor, *args, **kwargs): parts = weight.split(weight.shape[0] // self.num_chunks, dim=0) parts = [ fn(p) for fn, p in zip(self.parts, parts) ] return torch.cat(parts, dim=0) class _AttnSNReweight(_ChunkedSNReweight): def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs): super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs) if not renorm_values: self.parts[2] = nn.Identity() def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]], n_power_iterations: int = 1, eps: float = 1e-6, init_norm_to_current: bool = False, renorm_values: bool = True, renorm_mlp: bool = True, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None): if isinstance(model, (list, tuple)): for i, sub in enumerate(model): sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps, init_norm_to_current=init_norm_to_current, renorm_values=renorm_values, renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd) return print('Enabling spectral reparametrization') args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current) visited_prefixes = set() def is_guidance_parametrized(name: str): if state_dict_guidance is None: return True p_name = f'{name}.parametrizations' is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version')) return is_prm def parametrize_linear(linear: nn.Linear): parametrize.register_parametrization( linear, 'weight', 
_SNReweight(linear.weight, **args) ) for name, mod in model.named_modules(): pref = '.'.join(name.split('.')[:-1]) if pref in visited_prefixes: continue if isinstance(mod, Attention) or name.endswith('.attn'): if is_guidance_parametrized(f'{name}.qkv'): parametrize.register_parametrization( mod.qkv, 'weight', _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args), ) if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'): parametrize_linear(mod.proj) visited_prefixes.add(name) elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'): if is_guidance_parametrized(f'{name}.w12'): parametrize.register_parametrization( mod.w12, 'weight', _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args), ) if is_guidance_parametrized(f'{name}.w3'): parametrize_linear(mod.w3) visited_prefixes.add(name) elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name): parametrize_linear(mod) def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None): spectral_reparam = getattr(args, 'spectral_reparam', False) if isinstance(spectral_reparam, bool) and spectral_reparam: enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance) elif isinstance(spectral_reparam, dict): enable_spectral_reparam( model, n_power_iterations=spectral_reparam.get('n_power_iterations', 1), eps=spectral_reparam.get('eps', 1e-12), init_norm_to_current=True, state_dict_guidance=state_dict_guidance, ) def disable_spectral_reparam(model: nn.Module): print('Disabling spectral reparametrization') for name, mod in model.named_modules(): if parametrize.is_parametrized(mod): parametrize.remove_parametrizations(mod, 'weight') pass if __name__ == '__main__': import argparse from . 
import radio_model as create_model parser = argparse.ArgumentParser(description='Remove parametrization from state dict') parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load') parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint') parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields') parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict') args = parser.parse_args() if not args.output: chk_dir, chk_name = os.path.split(args.checkpoint) args.output = os.path.join(chk_dir, f'clean_{chk_name}') print(f'Set output to "{args.output}"') chk = torch.load(args.checkpoint, map_location='cpu', mmap=True) model = create_model.create_model_from_args(chk['args']) key = 'base_model.' mod_state = dict() extra_state = dict() for k, v in chk['state_dict'].items(): if k.startswith(key): mod_state[k[len(key):]] = v else: extra_state[k] = v chk_load_info = model.load_state_dict(mod_state, strict=args.strict) if chk_load_info.unexpected_keys or chk_load_info.missing_keys: print(chk_load_info) if chk['args'].spectral_reparam: disable_spectral_reparam(model) if hasattr(chk['args'], 'dtype'): model.to(dtype=chk['args'].dtype) mod_state = model.state_dict() final_state = dict() final_state.update({f'{key}{k}': v for k, v in mod_state.items()}) final_state.update(extra_state) chk['state_dict'] = final_state chk['args'].spectral_reparam = False if args.release: chk = { 'arch': chk['arch'], 'epoch': chk['epoch'], 'state_dict': chk['state_dict'], 'args': chk['args'], } torch.save(chk, args.output) pass ================================================ FILE: nit/models/nvidia_radio/radio/eradio_model.py ================================================ #!/usr/bin/env python3 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
# # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # E-RADIO model from # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023). # based on FasterViT, Swin Transformer, YOLOv8 # FasterViT: # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023). import timm import torch import torch.nn as nn from timm.models.registry import register_model from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d import numpy as np import torch.nn.functional as F import math import warnings ####################### ## Codebase from YOLOv8 ## BEGINNING ####################### class C2f(nn.Module): """Faster Implementation of CSP Bottleneck with 2 convolutions.""" """From YOLOv8 codebase""" def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion super().__init__() if drop_path is None: drop_path = [0.0] * n self.c = int(c2 * e) # hidden channels self.cv1 = Conv(c1, 2 * self.c, 1, 1) self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n)) def forward(self, x): """Forward pass through C2f layer.""" y = list(self.cv1(x).chunk(2, 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) def forward_split(self, x): """Forward pass using split() instead of chunk().""" y = 
list(self.cv1(x).split((self.c, self.c), 1)) y.extend(m(y[-1]) for m in self.m) return self.cv2(torch.cat(y, 1)) class Bottleneck(nn.Module): """Standard bottleneck.""" def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand super().__init__() c_ = int(c2 * e) # hidden channels self.cv1 = Conv(c1, c_, k[0], 1) self.cv2 = Conv(c_, c2, k[1], 1, g=g) self.add = shortcut and c1 == c2 self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): """'forward()' applies the YOLOv5 FPN to input data.""" return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x)) class Conv(nn.Module): """Modified to support layer fusion""" default_act = nn.SiLU() # default activation def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True): super().__init__() self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False) if 1: self.bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(self.bn.weight, bn_weight_init) torch.nn.init.constant_(self.bn.bias, 0) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self,x): x = self.conv(x) x = self.bn(x) x = self.act(x) return x @torch.no_grad() def switch_to_deploy(self): # return 1 if not isinstance(self.bn, nn.Identity): c, bn = self.conv, self.bn w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / \ (bn.running_var + bn.eps)**0.5 self.conv.weight.data.copy_(w) self.conv.bias = nn.Parameter(b) self.bn = nn.Identity() def autopad(k, p=None, d=1): # kernel, padding, dilation """Pad to 'same' shape outputs.""" if d > 1: k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size if p is None: p = k // 2 if 
isinstance(k, int) else [x // 2 for x in k] # auto-pad return p ####################### ## Codebase from YOLOv8 ## END ####################### def pixel_unshuffle(data, factor=2): # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually B, C, H, W = data.shape return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor) class SwiGLU(nn.Module): # should be more advanced, but doesnt improve results so far def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x def window_partition(x, window_size): """ Function for partitioning image into windows and later do windowed attention Args: x: (B, C, H, W) window_size: window size Returns: windows - local window features (num_windows*B, window_size*window_size, C) (Hp, Wp) - the size of the padded image """ B, C, H, W = x.shape if window_size == 0 or (window_size==H and window_size==W): windows = x.flatten(2).transpose(1, 2) Hp, Wp = H, W else: pad_h = (window_size - H % window_size) % window_size pad_w = (window_size - W % window_size) % window_size if pad_h > 0 or pad_w > 0: x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect") Hp, Wp = H + pad_h, W + pad_w x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size) windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C) return windows, (Hp, Wp) class Conv2d_BN(nn.Module): ''' Conv2d + BN layer with folding capability to speed up inference Can be merged with Conv() function with additional arguments ''' def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False): super().__init__() self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False) if 1: self.bn = torch.nn.BatchNorm2d(b) torch.nn.init.constant_(self.bn.weight, bn_weight_init) torch.nn.init.constant_(self.bn.bias, 0) def forward(self,x): x = self.conv(x) x = 
self.bn(x) return x @torch.no_grad() def switch_to_deploy(self): if not isinstance(self.bn, nn.Identity): c, bn = self.conv, self.bn w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / \ (bn.running_var + bn.eps)**0.5 self.conv.weight.data.copy_(w) self.conv.bias = nn.Parameter(b) self.bn = nn.Identity() def window_reverse(windows, window_size, H, W, pad_hw): """ Windows to the full feature map Args: windows: local window features (num_windows*B, window_size, window_size, C) window_size: Window size H: Height of image W: Width of image pad_w - a tuple of image passing used in windowing step Returns: x: (B, C, H, W) """ # print(f"window_reverse, windows.shape {windows.shape}") Hp, Wp = pad_hw if window_size == 0 or (window_size==H and window_size==W): B = int(windows.shape[0] / (Hp * Wp / window_size / window_size)) x = windows.transpose(1, 2).view(B, -1, H, W) else: B = int(windows.shape[0] / (Hp * Wp / window_size / window_size)) x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp) if Hp > H or Wp > W: x = x[:, :, :H, :W, ].contiguous() return x class PosEmbMLPSwinv2D(nn.Module): """ 2D positional embedding from Swin Transformer v2 Added functionality to store the positional embedding in the model and not recompute it every time """ def __init__( self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512, ): super().__init__() self.window_size = window_size self.num_heads = num_heads # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential( nn.Linear(2, cpb_mlp_hidden, bias=True), nn.ReLU(inplace=True), nn.Linear(cpb_mlp_hidden, num_heads, bias=False), ) self.grid_exists = False self.seq_length = seq_length self.deploy = False self.num_heads = num_heads self.no_log = no_log self.pretrained_window_size = 
pretrained_window_size self.relative_bias_window_size = window_size relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads, pretrained_window_size, seq_length, no_log) self.register_buffer("relative_coords_table", relative_coords_table) self.register_buffer("relative_position_index", relative_position_index) self.register_buffer("relative_bias", relative_bias) # for EMA def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log): # as in separate function to support window size chage after model weights loading relative_coords_h = torch.arange( -(window_size[0] - 1), window_size[0], dtype=torch.float32 ) relative_coords_w = torch.arange( -(window_size[1] - 1), window_size[1], dtype=torch.float32 ) relative_coords_table = ( torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) .permute(1, 2, 0) .contiguous() .unsqueeze(0) ) # 1, 2*Wh-1, 2*Ww-1, 2 if pretrained_window_size[0] > 0: relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 else: relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 if not no_log: relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = ( torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / np.log2(8) ) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) coords_w = torch.arange(self.window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = ( coords_flatten[:, :, None] - coords_flatten[:, None, :] ) # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute( 1, 2, 0 ).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.window_size[0] - 1 # shift 
to start from 0 relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_bias = torch.zeros(1, num_heads, seq_length, seq_length) self.relative_bias_window_size = window_size return relative_coords_table, relative_position_index, relative_bias def switch_to_deploy(self): self.deploy = True self.grid_exists = True def forward(self, input_tensor): # for efficiency, we want this forward to be folded into a single operation (sum) # if resolution stays the same, then we dont need to recompute MLP layers if not self.deploy or self.training: self.grid_exists = False #compare if all elements in self.window_size list match those in self.relative_bias_window_size if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]): relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads, self.pretrained_window_size, self.seq_length, self.no_log) self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device) self.relative_position_index = relative_position_index.to(self.relative_position_index.device) self.relative_bias = relative_bias.to(self.relative_bias.device) if self.deploy and self.grid_exists: input_tensor = input_tensor + self.relative_bias return input_tensor if 1: self.grid_exists = True relative_position_bias_table = self.cpb_mlp( self.relative_coords_table ).view(-1, self.num_heads) relative_position_bias = relative_position_bias_table[ self.relative_position_index.view(-1) ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1, ) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute( 2, 0, 1 ).contiguous() # nH, Wh*Ww, Wh*Ww relative_position_bias = 16 * torch.sigmoid(relative_position_bias) self.relative_bias = 
relative_position_bias.unsqueeze(0) input_tensor = input_tensor + self.relative_bias return input_tensor class GRAAttentionBlock(nn.Module): def __init__(self, window_size, dim_in, dim_out, num_heads, drop_path=0., qk_scale=None, qkv_bias=False, norm_layer=nn.LayerNorm, layer_scale=None, use_swiglu=True, subsample_ratio=1, dim_ratio=1, conv_base=False, do_windowing=True, multi_query=False, use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0): ''' Global Resolution Attention Block , see README for details Attention with subsampling to get a bigger receptive field for attention conv_base - use conv2d instead of avgpool2d for downsample / upsample ''' super().__init__() self.shift_size=window_size//2 if use_shift else 0 self.do_windowing = do_windowing self.subsample_ratio = subsample_ratio if do_windowing: if conv_base: self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity() self.downsample_mixer = nn.Identity() self.upsample_mixer = nn.Identity() self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity() else: self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity() self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity() self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity() self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity() # in case there is no downsampling conv we want to have it separately # will help with information propagation between windows if subsample_ratio == 1: # conv_groups_ratio=0 self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, 
groups=max(1,int(conv_groups_ratio*dim_in)), bias=False) # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False) # self.pre_conv_act = nn.ReLU6() #for simplicity: self.pre_conv_act = nn.Identity() if conv_groups_ratio == -1: self.pre_conv = nn.Identity() self.pre_conv_act = nn.Identity() self.window_size = window_size self.norm1 = norm_layer(dim_in) self.attn = WindowAttention( dim_in, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, resolution=window_size, seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query, shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float] self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1 ### mlp layer mlp_ratio = 4 self.norm2 = norm_layer(dim_in) mlp_hidden_dim = int(dim_in * mlp_ratio) activation = nn.GELU if not use_swiglu else SwiGLU mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu) self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1 self.drop_path2=DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x): skip_connection = x attn_mask = None # in case there is no downsampling conv we want to have it separately # will help with information propagation if self.subsample_ratio == 1: x = self.pre_conv_act(self.pre_conv(x)) + skip_connection if self.do_windowing: # performing windowing if required x = self.downsample_op(x) x = self.downsample_mixer(x) if self.window_size>0: H, W = x.shape[2], x.shape[3] if self.shift_size > 0 and H>self.window_size and W>self.window_size: # @swin like cyclic shift, doesnt show better performance x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3)) x, pad_hw = window_partition(x, self.window_size) if self.shift_size > 0 and H>self.window_size and W>self.window_size: # set atten matrix to have -100 and the top right square # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0 # calculate attention mask for SW-MSA # not used in final version, can be useful for some cases especially for high res H, W = pad_hw img_mask = torch.zeros((1, H, W, 1), device=x.device) # 1 H W 1 h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None)) cnt = 0 for h in h_slices: for w in w_slices: img_mask[:, h, w, :] = cnt cnt += 1 img_mask = img_mask.transpose(1,2).transpose(1,3) mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) # window attention x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask)) # or pass H,W # mlp layer x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x))) if self.do_windowing: 
if self.window_size > 0: x = window_reverse(x, self.window_size, H, W, pad_hw) # reverse cyclic shift if self.shift_size > 0 and H>self.window_size and W>self.window_size: # @swin like cyclic shift, not tested x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3)) x = self.upsample_mixer(x) x = self.upsample_op(x) if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]: x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode="reflect") # need to add skip connection because downsampling and upsampling will break residual connection # 0.5 is needed to make sure that the skip connection is not too strong # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection x = 0.5 * x + 0.5 * skip_connection return x class MultiResolutionAttention(nn.Module): """ MultiResolutionAttention (MRA) module The idea is to use multiple attention blocks with different resolution Feature maps are downsampled / upsampled for each attention block on different blocks Every attention block supports windowing """ def __init__(self, window_size, sr_ratio, dim, dim_ratio, num_heads, do_windowing=True, layer_scale=1e-5, norm_layer=nn.LayerNorm, drop_path = 0, qkv_bias=False, qk_scale=1.0, use_swiglu=True, multi_query=False, conv_base=False, use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None: """ Args: input_resolution: input image resolution window_size: window size compression_ratio: compression ratio max_depth: maximum depth of the GRA module use_shift: do window shifting """ super().__init__() depth = len(sr_ratio) self.attention_blocks = nn.ModuleList() for i in range(depth): subsample_ratio = sr_ratio[i] if len(window_size) > i: window_size_local = window_size[i] else: window_size_local = window_size[0] self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local, dim_in=dim, dim_out=dim, num_heads=num_heads, 
qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer, layer_scale=layer_scale, drop_path=drop_path, use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio, do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base, use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio), ) def forward(self, x): for attention_block in self.attention_blocks: x = attention_block(x) return x class Mlp(nn.Module): """ Multi-Layer Perceptron (MLP) block """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, use_swiglu=True, drop=0.): """ Args: in_features: input features dimension. hidden_features: hidden features dimension. out_features: output features dimension. act_layer: activation function. drop: dropout rate. """ super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features, bias=False) def forward(self, x): x_size = x.size() x = x.view(-1, x_size[-1]) x = self.fc1(x) x = self.act(x) x = self.fc2(x) x = x.view(x_size) return x class Downsample(nn.Module): """ Down-sampling block Pixel Unshuffle is used for down-sampling, works great accuracy - wise but takes 10% more TRT time """ def __init__(self, dim, shuffle = False, ): """ Args: dim: feature size dimension. shuffle: idea with keep_dim: bool argument for maintaining the resolution. 
""" super().__init__() dim_out = 2 * dim if shuffle: self.norm = lambda x: pixel_unshuffle(x, factor=2) self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False) # pixel unshuffleging works well but doesnt provide any speedup else: # removed layer norm for better, in this formulation we are getting 10% better speed # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension # therefore we remove it compared to the original implementation in FasterViT self.norm = nn.Identity() self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False) def forward(self, x): x = self.norm(x) x = self.reduction(x) return x class PatchEmbed(nn.Module): """ Patch embedding block Used to convert image into an initial set of feature maps with lower resolution """ def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False): """ Args: in_chans: number of input channels. in_dim: intermediate feature size dimension to speed up stem. dim: final stem channel number shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field """ super().__init__() # shuffle_down = False if not shuffle_down: self.proj = nn.Identity() self.conv_down = nn.Sequential( Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False), nn.ReLU(), Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False), nn.ReLU() ) else: self.proj = lambda x: pixel_unshuffle(x, factor=4) self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1), nn.ReLU(), ) def forward(self, x): x = self.proj(x) x = self.conv_down(x) return x class ConvBlock(nn.Module): """ Convolutional block, used in first couple of stages Experimented with plan resnet-18 like modules, they are the best in terms of throughput Finally, YOLOv8 idea seem to work fine (resnet-18 like block with squeezed feature dimension, and feature concatendation at the end) """ def __init__(self, dim, drop_path=0., layer_scale=None, kernel_size=3, ): super().__init__() self.conv1 = Conv2d_BN(dim, dim, 
kernel_size=kernel_size, stride=1, padding=1) self.act1 = nn.GELU() self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1) self.layer_scale = layer_scale if layer_scale is not None and type(layer_scale) in [int, float]: self.gamma = nn.Parameter(layer_scale * torch.ones(dim)) self.layer_scale = True else: self.layer_scale = False self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): input = x x = self.conv1(x) x = self.act1(x) x = self.conv2(x) if self.layer_scale: x = x * self.gamma.view(1, -1, 1, 1) x = input + self.drop_path(x) return x class WindowAttention(nn.Module): # Windowed Attention from SwinV2 # use a MLP trick to deal with various input image resolutions, then fold it to improve speed def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0, seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512): # taken from EdgeViT and tweaked with attention bias. super().__init__() if not dim_out: dim_out = dim self.shift_size = shift_size self.multi_query = multi_query self.num_heads = num_heads head_dim = dim // num_heads self.head_dim = dim // num_heads self.dim_internal = dim self.scale = qk_scale or head_dim ** -0.5 if not multi_query: self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) else: self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias) self.proj = nn.Linear(dim, dim_out, bias=False) # attention positional bias self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution], pretrained_window_size=[resolution, resolution], num_heads=num_heads, seq_length=seq_length, cpb_mlp_hidden=cpb_mlp_hidden) self.resolution = resolution def forward(self, x, attn_mask = None): B, N, C = x.shape if not self.multi_query: qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] else: qkv = self.qkv(x) (q, k, v) = qkv.split([self.dim_internal, self.head_dim, 
self.head_dim], dim=2) q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3) v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3) attn = (q @ k.transpose(-2, -1)) * self.scale attn = self.pos_emb_funct(attn) #add window shift if attn_mask is not None: nW = attn_mask.shape[0] attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, -1, C) x = self.proj(x) return x class ERADIOLayer(nn.Module): """ E-RADIO Layer """ def __init__(self, dim, depth, num_heads, window_size, conv=False, downsample=True, mlp_ratio=4., qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, drop_path=0., layer_scale=None, layer_scale_conv=None, sr_dim_ratio=1, sr_ratio=1, multi_query=False, use_swiglu=True, yolo_arch=False, downsample_shuffle=False, conv_base=False, use_shift=False, cpb_mlp_hidden=512, conv_groups_ratio=0, verbose: bool = True, ): """ Args: dim: feature size dimension. depth: number of layers in each stage. input_resolution: input image resolution. window_size: window size in each stage. downsample: bool argument for down-sampling. mlp_ratio: MLP ratio. num_heads: number of heads in each stage. qkv_bias: bool argument for query, key, value learnable bias. qk_scale: bool argument to scaling query, key. drop: dropout rate. attn_drop: attention dropout rate. drop_path: drop path rate. norm_layer: normalization layer. layer_scale: layer scaling coefficient. 
use_shift: SWIN like window shifting for half the window size for every alternating layer (considering multi-resolution) conv_groups_ratio: group ratio for conv when no subsampling in multi-res attention """ super().__init__() self.conv = conv self.yolo_arch=False self.verbose = verbose if conv: if not yolo_arch: self.blocks = nn.ModuleList([ ConvBlock(dim=dim, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, layer_scale=layer_scale_conv) for i in range(depth)]) self.blocks = nn.Sequential(*self.blocks) else: self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5) self.yolo_arch=True else: if not isinstance(window_size, list): window_size = [window_size] self.window_size = window_size[0] self.do_single_windowing = True if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio] self.sr_ratio = sr_ratio if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1: self.do_single_windowing = False do_windowing = True else: self.do_single_windowing = True do_windowing = False #for v2_2 if conv_groups_ratio != -1: self.do_single_windowing = False do_windowing = True self.blocks = nn.ModuleList() for i in range(depth): self.blocks.append( MultiResolutionAttention(window_size=window_size, sr_ratio=sr_ratio, dim=dim, dim_ratio = sr_dim_ratio, num_heads=num_heads, norm_layer=norm_layer, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, layer_scale=layer_scale, qkv_bias=qkv_bias, qk_scale=qk_scale, use_swiglu=use_swiglu, do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base, cpb_mlp_hidden=cpb_mlp_hidden, use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True , conv_groups_ratio=conv_groups_ratio, )) self.blocks = nn.Sequential(*self.blocks) self.transformer = not conv self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle) def forward(self, x): B, C, H, W = x.shape # do padding for transforemr interpolate = True if self.transformer and interpolate: # Windowed Attention will 
split feature map into windows with the size of window_size x window_size # if the resolution is not divisible by window_size, we need to interpolate the feature map # can be done via padding, but doing so after training hurts the model performance. # interpolation affects the performance as well, but not as much as padding if isinstance(self.window_size, list) or isinstance(self.window_size, tuple): current_max_window_size = max(self.window_size) else: current_max_window_size = self.window_size max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio]) if H % max_window_size != 0 or W % max_window_size != 0: new_h = int(np.ceil(H/max_window_size)*max_window_size) new_w = int(np.ceil(W/max_window_size)*max_window_size) x = F.interpolate(x, size=(new_h, new_w), mode='nearest') if self.verbose: warnings.warn(f"Choosen window size is not optimal for given resolution. Interpolation of features maps will be done and it can affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.") if self.transformer and self.do_single_windowing: H, W = x.shape[2], x.shape[3] x, pad_hw = window_partition(x, self.window_size) #run main blocks x = self.blocks(x) if self.transformer and self.do_single_windowing: x = window_reverse(x, self.window_size, H, W, pad_hw) if self.transformer and interpolate: #lets keep original resolution, might be not ideal, but for the upsampling tower we need to keep the expected resolution. 
x = F.interpolate(x, size=(H, W), mode='nearest') if self.downsample is None: return x, x return self.downsample(x), x # changing to output pre downsampled features class InterpolateLayer(nn.Module): def __init__(self, size=None, scale_factor=None, mode='nearest'): super(InterpolateLayer, self).__init__() self.size = size self.scale_factor = scale_factor self.mode = mode def forward(self, x): return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode) class HiResNeck(nn.Module): """ The block is used to output dense features from all stages Otherwise, by default, only the last stage features are returned with E-RADIO """ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled): ''' Hi Resolution neck to support output of high res features that are useful for dense tasks. depths - total number of layers in the base model neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc. 
earlier layers result in higher resolution features at the cost of compute full_features_head_dim - number of channels in the dense features head ''' super().__init__() # create feature projection layers for segmentation output self.neck_features_proj = nn.ModuleList() self.neck_start_stage = neck_start_stage upsample_ratio = 1 for i in range(len(depths)): level_n_features_output = int(dim * 2 ** i) if self.neck_start_stage > i: continue if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output: feature_projection = nn.Sequential() if False: feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output, full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio)) else: # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio # print("upsample ratio", upsample_ratio, level_n_features_output, level_n_features_output) feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest')) feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output)) feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) # B, in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio feature_projection.add_module("conv2", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0)) else: feature_projection = nn.Sequential() self.neck_features_proj.append(feature_projection) if i>0 and downsample_enabled[i]: upsample_ratio *= 2 def forward(self, x, il_level=-1, full_features=None): if self.neck_start_stage > il_level: return full_features if full_features is None: full_features = self.neck_features_proj[il_level - self.neck_start_stage](x) else: #upsample torch tensor x to 
match full_features size, and add to full_features feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x) if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]: feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2])) full_features = full_features + feature_projection return full_features class ERADIO(nn.Module): """ Efficient RADIO """ def __init__(self, dim, in_dim, depths, window_size, mlp_ratio, num_heads, drop_path_rate=0.2, in_chans=3, num_classes=1000, qkv_bias=False, qk_scale=None, layer_scale=None, layer_scale_conv=None, layer_norm_last=False, sr_ratio = [1, 1, 1, 1], max_depth = -1, conv_base=False, use_swiglu=False, multi_query=False, norm_layer=nn.LayerNorm, drop_uniform=False, yolo_arch=False, shuffle_down=False, downsample_shuffle=False, return_full_features=False, full_features_head_dim=128, neck_start_stage=1, use_neck=False, use_shift=False, cpb_mlp_hidden=512, conv_groups_ratio=0, verbose: bool = False, **kwargs): """ Args: dim: feature size dimension. depths: number of layers in each stage. window_size: window size in each stage. mlp_ratio: MLP ratio. num_heads: number of heads in each stage. drop_path_rate: drop path rate. in_chans: number of input channels. num_classes: number of classes. qkv_bias: bool argument for query, key, value learnable bias. qk_scale: bool argument to scaling query, key. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. norm_layer: normalization layer. layer_scale: layer scaling coefficient. return_full_features: output dense features as well as logits full_features_head_dim: number of channels in the dense features head neck_start_stage: a stage id to start full feature neck. 
Model has 4 stages, indix starts with 0 for 224 resolution, the output of the stage before downsample: stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7 use_neck: even for summarization embedding use neck use_shift: SWIN like window shifting but without masking attention conv_groups_ratio: will be used for conv blocks where there is no multires attention, if 0 then normal conv, if 1 then channels are independent, if -1 then no conv at all """ super().__init__() num_features = int(dim * 2 ** (len(depths) - 1)) self.num_classes = num_classes self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down) # set return_full_features true if we want to return full features from all stages self.return_full_features = return_full_features self.use_neck = use_neck dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] if drop_uniform: dpr = [drop_path_rate for x in range(sum(depths))] if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths) self.levels = nn.ModuleList() for i in range(len(depths)): conv = True if (i == 0 or i == 1) else False level = ERADIOLayer(dim=int(dim * 2 ** i), depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, conv=conv, drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], downsample=(i < len(depths) - 1), layer_scale=layer_scale, layer_scale_conv=layer_scale_conv, sr_ratio=sr_ratio[i], use_swiglu=use_swiglu, multi_query=multi_query, norm_layer=norm_layer, yolo_arch=yolo_arch, downsample_shuffle=downsample_shuffle, conv_base=conv_base, cpb_mlp_hidden=cpb_mlp_hidden, use_shift=use_shift, conv_groups_ratio=conv_groups_ratio, verbose=verbose) self.levels.append(level) if self.return_full_features or self.use_neck: #num_heads downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))] self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, 
full_features_head_dim, downsample_enabled) self.switched_to_deploy = False self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features) self.avgpool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, LayerNorm2d): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) @torch.jit.ignore def no_weight_decay_keywords(self): return {'rpb'} def forward_features(self, x): _, _, H, W = x.shape if H % 32 != 0 or W % 32 != 0: raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}") x = self.patch_embed(x) full_features = None for il, level in enumerate(self.levels): x, pre_downsample_x = level(x) if self.return_full_features or self.use_neck: full_features = self.high_res_neck(pre_downsample_x, il, full_features) # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x) x = self.norm(x) # new version for if not self.return_full_features: return x, None return x, full_features def forward(self, x): x, full_features = self.forward_features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.head(x) if full_features is not None: return x, full_features return x def switch_to_deploy(self): ''' A method to perform model self-compression merges BN into conv layers converts MLP relative positional bias into precomputed buffers ''' if not self.switched_to_deploy: for level in [self.patch_embed, self.levels, self.head]: for module in level.modules(): if hasattr(module, 'switch_to_deploy'): module.switch_to_deploy() 
self.switched_to_deploy = True def change_window_size(self, new_window_size): """ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter, especially in cases of uneven partitioning of the feature maps. E-RADIO allows for the adjustment of the window size after training, making it adaptable to different input image resolutions. The recommended values for window size based on input resolution are as follows: Input Resolution | Window Size 224 | 7 256 | 8 386 | 12 512 | 16 Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be img_res/16/2 for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size. Manual way to change resolution -> model.change_window_size(resolution) """ window_size = new_window_size print(f"Setting window size to {window_size}") for module in self.modules(): if hasattr(module, "window_size"): # check if tuple or a number if isinstance(module.window_size, tuple): if module.window_size[0] != window_size: module.window_size = (window_size, window_size) elif isinstance(module.window_size, list): if module.window_size[0] != window_size: module.window_size = [window_size, window_size] else: module.window_size = window_size def set_optimal_window_size(self, image_dim, max_window_size = 16): """ Using hand picked window size for various resolutions. E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter, especially in cases of uneven partitioning of the feature maps. E-RADIO allows for the adjustment of the window size after training, making it adaptable to different input image resolutions. 
The recommended values for window size based on input resolution are as follows: Input Resolution | Window Size 224 | 7 256 | 8 386 | 12 512 | 16 Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be img_res/16/2 for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size. Manual way to change resolution -> model.change_window_size(resolution) """ # import math def divisorGenerator(n): large_divisors = [] for i in range(1, int(math.sqrt(n) + 1)): if n % i == 0: yield i if i*i != n: large_divisors.append(n / i) for divisor in reversed(large_divisors): yield divisor if isinstance(image_dim, list) or isinstance(image_dim, tuple): image_dim = min(image_dim) # we do windowed attention in the 3rd stage for the first time, therefore //16, # we do subsampled attention with downsample by 2 so need to get //32 actually # ideally we should rewrite this to be dependent on the structure of the model like what if subsampled is removed etc all_divisors = np.array(list(divisorGenerator(image_dim//32))) new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size)) # for image_dim in [128, 224, 256, 384, 512, 768, 1024]: # all_divisors = np.array(list(divisorGenerator(image_dim//32))) # new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size)) # print(f"Setting window size to {new_window_size} for image resolution {image_dim}") self.change_window_size(new_window_size = new_window_size) @register_model def eradio_large_fullres_ws16(pretrained=False, **kwargs): model = ERADIO( depths=[3, 3, 5, 5], num_heads=[2, 4, 8, 16], window_size=[None, None, [16, 16], 16], dim=192, in_dim=64, mlp_ratio=4, drop_path_rate=0.0, sr_ratio=[1, 1, [2, 1], 1], use_swiglu=False, yolo_arch=True, shuffle_down=False, conv_base=True, use_neck=True, 
full_features_head_dim=1536, neck_start_stage=2, **kwargs, ) if pretrained: model.load_state_dict(torch.load(pretrained)["state_dict"]) return model @register_model def eradio_xxxtiny(pretrained=False, **kwargs): # , model = ERADIO( depths=[1, 3, 4, 5], num_heads=[2, 4, 8, 16], window_size=[None, None, [16, 16], 16], dim=32, in_dim=32, mlp_ratio=4, drop_path_rate=0.0, sr_ratio=[1, 1, [2, 1], 1], use_swiglu=False, yolo_arch=True, shuffle_down=False, conv_base=True, use_neck=True, full_features_head_dim=256, neck_start_stage=2, **kwargs, ) if pretrained: model.load_state_dict(torch.load(pretrained)) return model @register_model def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs): model = ERADIO(depths=[1, 3, 4, 5], num_heads=[2, 4, 8, 16], window_size=[None, None, [12, 12], 12], dim=32, in_dim=32, mlp_ratio=4, drop_path_rate=0.0, sr_ratio=[1, 1, [2, 1], 1], use_swiglu=False, downsample_shuffle=False, yolo_arch=True, shuffle_down=False, cpb_mlp_hidden=64, use_neck=True, full_features_head_dim=256, neck_start_stage=2, conv_groups_ratio = 1, **kwargs) if pretrained: model.load_state_dict(torch.load(pretrained)["state_dict"]) return model @register_model def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs): model = ERADIO(depths=[1, 3, 4, 5], num_heads=[2, 4, 8, 16], window_size=[None, None, [16, 16], 16], dim=32, in_dim=32, mlp_ratio=4, drop_path_rate=0.0, sr_ratio=[1, 1, [2, 1], 1], use_swiglu=False, downsample_shuffle=False, yolo_arch=True, shuffle_down=False, cpb_mlp_hidden=64, use_neck=True, full_features_head_dim=256, neck_start_stage=1, conv_groups_ratio = 1, **kwargs) if pretrained: model.load_state_dict(torch.load(pretrained)["state_dict"]) return model @register_model def eradio(pretrained=False, **kwargs): return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs) ================================================ FILE: nit/models/nvidia_radio/radio/extra_models.py ================================================ from distutils.version import 
LooseVersion from types import MethodType from typing import List, Optional, Tuple, Union import warnings import torch from torch import nn import torch.nn.functional as F from timm.models.registry import register_model from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .forward_intermediates import forward_intermediates from .input_conditioner import InputConditioner _has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention') class PaliGemmaWrapper(nn.Module): def __init__(self, vis_model: nn.Module, embed_dim: int): super().__init__() self.vis_model = vis_model self.embed_dim = embed_dim @property def patch_size(self): return self.vis_model.embeddings.patch_size @property def blocks(self): return self.vis_model.encoder.layers @property def embed_dim(self): return self.vis_model.embeddings.embed_dim def forward(self, x: torch.Tensor): outputs = self.vis_model( x, return_dict=False, interpolate_pos_encoding=True, ) features = outputs[0].to(torch.float32) summary = features.mean(dim=1) return summary, features def forward_features(self, x: torch.Tensor): return self(x) def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16): from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version if LooseVersion(tx_version) > LooseVersion('4.44.2'): warnings.warn(f'Your transformers version "{tx_version}" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.') extra_args = dict() if dtype is not None: extra_args['torch_dtype'] = dtype rev = str(dtype).split('.')[-1] extra_args['revision'] = rev model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args) vis_model = model.vision_tower.vision_model vis_model = PaliGemmaWrapper(vis_model, embed_dim) return vis_model @register_model def paligemma_896_student(**kwargs): model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None) return model def dv2_sdpa(self, x: 
torch.Tensor) -> torch.Tensor: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] x = F.scaled_dot_product_attention( q, k, v, is_causal=False, dropout_p=self.attn_drop.p if self.training else 0., scale=self.scale, ) x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs): if cache_dir: torch.hub.set_dir(cache_dir) model: nn.Module = torch.hub.load( 'facebookresearch/dinov2', dino_v2_model, pretrained=pretrained, # **kwargs, ) if _has_torch_sdpa: for n, m in model.named_modules(): if n.endswith('.attn'): m.forward = MethodType(dv2_sdpa, m) return model class DinoWrapper(nn.Module): def __init__(self, dino_model: nn.Module): super().__init__() self.inner = dino_model dino_model.blocks = nn.Sequential(*dino_model.blocks) @property def embed_dim(self): return self.inner.embed_dim @property def patch_size(self): return self.inner.patch_size @property def num_cls_tokens(self): return getattr(self.inner, 'num_tokens', 1) @property def num_registers(self): return getattr(self.inner, 'num_register_tokens', 0) @property def num_summary_tokens(self): return self.num_cls_tokens + self.num_registers @property def blocks(self): return self.inner.blocks def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: parts = self.inner.forward_features(*args, **kwargs) cls_token = parts['x_norm_clstoken'] features = parts['x_norm_patchtokens'] return cls_token, features def forward_features(self, x: torch.Tensor): x = self.inner.prepare_tokens_with_masks(x) x = self.inner.blocks(x) x_norm = self.inner.norm(x) return x_norm[:, 0], x_norm[:, self.num_summary_tokens:] def patchify(self, x: torch.Tensor) -> torch.Tensor: return self.inner.prepare_tokens_with_masks(x) def forward_intermediates(self, x: torch.Tensor, norm: bool = False, **kwargs, ) -> 
Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: return forward_intermediates( self, patch_extractor=self.inner.prepare_tokens_with_masks, num_summary_tokens=self.num_summary_tokens, num_cls_tokens=self.num_cls_tokens, norm=self.inner.norm if norm else lambda y: y, x=x, **kwargs, ) def _dino_student(arch: str, **kwargs): from . import dinov2_arch factory = getattr(dinov2_arch, arch) model = factory() model = DinoWrapper(model) conditioner = InputConditioner( input_scale=1.0, norm_mean=IMAGENET_DEFAULT_MEAN, norm_std=IMAGENET_DEFAULT_STD, ) model.input_conditioner = conditioner return model @register_model def dino_v2_l_student(**kwargs): return _dino_student('dinov2_vitl14_reg', **kwargs) @register_model def dino_v2_g_student(**kwargs): return _dino_student('dinov2_vitg14_reg', **kwargs) ================================================ FILE: nit/models/nvidia_radio/radio/extra_timm_models.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import math import warnings import torch from torch import nn from torch.nn import functional as F from timm.models import register_model from timm.models.vision_transformer import ( VisionTransformer, _create_vision_transformer as _timm_create_vision_transformer, Mlp, Block, LayerScale as TIMMLayerScale, ) # Import these to also register them from . 
import dinov2_arch @register_model def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Tiny (Vit-Ti/16) """ model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3) model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/16) """ model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14 ) model = _create_vision_transformer( 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ name = 'vit_large_patch14_reg4_dinov2' model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14 ) model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929). """ model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16) if pretrained: # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs)) else: model = _create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs)) return model @register_model def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929). """ model = vit_huge_patch16_224(pretrained=pretrained, **kwargs) for m in model.modules(): if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm): m.norm = nn.LayerNorm(m.fc1.out_features) return model @register_model def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer: """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929). 
""" model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24) model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs)) if scaled_ln: _apply_scaled_ln(model) return model @register_model def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6) model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs)) return model def _create_vision_transformer(*args, **kwargs): model = _timm_create_vision_transformer(*args, **kwargs) _patch_layer_scale(model) return model def _patch_layer_scale(model: VisionTransformer): def replace_ls(old_ls: TIMMLayerScale): new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace) new_ls.load_state_dict(old_ls.state_dict()) return new_ls # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name # other than gamma, so that HFHub doesn't mess with it! 
for mod in model.modules(): if isinstance(mod, Block): if isinstance(mod.ls1, TIMMLayerScale): mod.ls1 = replace_ls(mod.ls1) if isinstance(mod.ls2, TIMMLayerScale): mod.ls2 = replace_ls(mod.ls2) pass class ScaledLayerNorm(nn.LayerNorm): ''' https://arxiv.org/pdf/2502.05795v1 ''' def __init__(self, ln_base: nn.LayerNorm, depth: int = 0): super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine) self.load_state_dict(ln_base.state_dict()) self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False) def forward(self, x): y = super().forward(x) y = y * self.ln_scale return y class DyT(nn.Module): def __init__(self, C: int, init_alpha: float): super().__init__() self.alpha = nn.Parameter(torch.full((1,), init_alpha)) self.gamma = nn.Parameter(torch.ones(C)) self.beta = nn.Parameter(torch.zeros(C)) def forward(self, x: torch.Tensor): x = F.tanh(self.alpha * x) return self.gamma * x + self.beta @register_model def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
""" model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int): return DyT(ln.normalized_shape[0], init_alpha=0.9) _replace_ln(model, _replace_ln_with_dyt) return model def _apply_scaled_ln(model: VisionTransformer): warnings.warn('Post-LayerNorm scaling activated!') _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth)) def _replace_ln(model: VisionTransformer, fn): def _inner_replace_ln(block: Block, depth: int, key: str): prev = getattr(block, key) if isinstance(prev, nn.LayerNorm): setattr(block, key, fn(prev, depth=depth)) for i, block in enumerate(model.blocks): _inner_replace_ln(block, i + 1, 'norm1') _inner_replace_ln(block, i + 1, 'norm2') ================================================ FILE: nit/models/nvidia_radio/radio/feature_normalizer.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from collections import namedtuple from typing import NamedTuple, Optional, Tuple import torch from torch import nn def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor): if x.ndim <= 3: x = x - mean x = x @ tx.T elif x.ndim == 4: x = x - mean.reshape(1, -1, 1, 1) kernel = tx.reshape(*tx.shape, 1, 1) x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0) else: raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}') return x class FeatureNormalizer(nn.Module): def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32): super().__init__() self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype)) self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: x = _run_kernel(x, self.mean, self.tx) return x class InterFeatState(NamedTuple): y: torch.Tensor alpha: torch.Tensor class IntermediateFeatureNormalizerBase(nn.Module): def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: raise NotImplementedError() class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32): super().__init__() self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype)) rot = torch.eye(embed_dim, dtype=dtype) if rot_per_layer: rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1) self.register_buffer('rotation', rot.contiguous()) self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype)) def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: if rot_index is None: rot_index = index if skip: assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.' 
prefix, x = x[:, :skip], x[:, skip:] rotation = self._get_rotation(rot_index) y = _run_kernel(x, self.means[index], rotation) alpha = self.alphas[index] if skip: alpha = torch.cat([ torch.ones(skip, dtype=alpha.dtype, device=alpha.device), alpha[None].expand(y.shape[1]), ]).reshape(1, -1, 1) y = torch.cat([prefix, y], dim=1) else: if x.ndim == 3: alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1) elif x.ndim == 4: alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:]) else: raise ValueError(f'Unsupported input dimension: {x.ndim}') return InterFeatState(y, alpha) def _get_rotation(self, rot_index: int) -> torch.Tensor: if self.rotation.ndim == 2: return self.rotation return self.rotation[rot_index] class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase): instances = dict() def __init__(self, dtype: torch.dtype, device: torch.device): super().__init__() self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device)) @staticmethod def get_instance(dtype: torch.dtype, device: torch.device): instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None) if instance is None: instance = NullIntermediateFeatureNormalizer(dtype, device) NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance return instance def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState: return InterFeatState(x, self.alpha) ================================================ FILE: nit/models/nvidia_radio/radio/forward_intermediates.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. 
Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable from types import MethodType import torch from torch import nn from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer def _take_indices( num_blocks: int, n: Optional[Union[int, List[int], Tuple[int]]], ) -> Tuple[Set[int], int]: if isinstance(n, int): assert n >= 0 take_indices = {x for x in range(num_blocks - n, num_blocks)} else: take_indices = {num_blocks + idx if idx < 0 else idx for idx in n} return take_indices, max(take_indices) def forward_intermediates( model: nn.Module, patch_extractor: Callable[[torch.Tensor], torch.Tensor], norm: nn.Module, num_summary_tokens: int, num_cls_tokens: int, x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, aggregation: Optional[str] = "sparse", inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None, norm_alpha_scheme = "post-alpha", block_kwargs: Dict = None, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """ Forward features that returns intermediates. The Dense layer aggregation method is inspired from the paper: "Dense Connector for MLLMs" by Yao, Huanjin et al. (2024). 
arXiv preprint arXiv:2405.13800} Args: x: Input image tensor indices: Take last n blocks if int, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features aggregation: intermediate layer aggregation method (sparse or dense) norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha") Returns: """ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' assert aggregation in ('sparse', 'dense'), 'Aggregation must be one of sparse or dense.' reshape = output_fmt == 'NCHW' intermediates = [] block_kwargs = block_kwargs or dict() blocks = model.blocks take_indices, max_index = _take_indices(len(blocks), indices) take_indices = sorted(take_indices) # forward pass B, _, height, width = x.shape x = patch_extractor(x) if stop_early: blocks = blocks[:max_index + 1] if inter_feature_normalizer is None or norm_alpha_scheme == 'none': inter_feature_normalizer = NullIntermediateFeatureNormalizer.get_instance(x.dtype, x.device) assert norm_alpha_scheme in ('none', 'pre-alpha', 'post-alpha'), f'Unsupported alpha scheme: {norm_alpha_scheme}' post_alpha_scheme = norm_alpha_scheme == 'post-alpha' accumulator = 0 alpha_sum = 0 num_accumulated = 0 take_off = 0 for i, blk in enumerate(blocks): x = blk(x, **block_kwargs) if aggregation == "dense": # Arbitrarily use the rotation matrix from the final layer in the dense group y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens) if post_alpha_scheme: accumulator = accumulator + y alpha_sum = alpha_sum + alpha else: accumulator = accumulator + (alpha * y) alpha_sum += 1 num_accumulated += 1 if i == take_indices[take_off]: if aggregation == "dense": alpha = alpha_sum / 
num_accumulated x_ = alpha * accumulator / num_accumulated num_accumulated = 0 accumulator = 0 alpha_sum = 0 else: y, alpha = inter_feature_normalizer(x, i, skip=num_summary_tokens) x_ = alpha * y # normalize intermediates with final norm layer if enabled intermediates.append(norm(x_)) take_off = min(take_off + 1, len(take_indices) - 1) # process intermediates # split prefix (e.g. class, distill) and spatial feature tokens prefix_tokens = [y[:, :num_cls_tokens] for y in intermediates] intermediates = [y[:, num_summary_tokens:] for y in intermediates] if reshape: # reshape to BCHW output format H = height // model.patch_size W = width // model.patch_size intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] if not torch.jit.is_scripting() and return_prefix_tokens: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(prefix_tokens, intermediates)) if intermediates_only: return intermediates x = norm(x) return x, intermediates ================================================ FILE: nit/models/nvidia_radio/radio/hf_model.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from collections import namedtuple from typing import Callable, Dict, Optional, List, Union from timm.models import VisionTransformer import torch from torch import nn from transformers import PretrainedConfig, PreTrainedModel from .common import RESOURCE_MAP, DEFAULT_VERSION # Import all required modules. from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput from .adaptor_generic import GenericAdaptor, AdaptorBase from .adaptor_mlp import create_mlp_from_config from .adaptor_registry import adaptor_registry from .cls_token import ClsToken from .dinov2_arch import dinov2_vitg14_reg from .enable_cpe_support import enable_cpe from .enable_spectral_reparam import configure_spectral_reparam_from_args from .eradio_model import eradio from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer from .forward_intermediates import forward_intermediates from .radio_model import create_model_from_args from .radio_model import RADIOModel as RADIOModelBase, Resolution from .input_conditioner import get_default_conditioner, InputConditioner from .open_clip_adaptor import OpenCLIP_RADIO from .vit_patch_generator import ViTPatchGenerator from .vitdet import apply_vitdet_arch, VitDetArgs # Register extra models from .extra_timm_models import * from .extra_models import * class RADIOConfig(PretrainedConfig): """Pretrained Hugging Face configuration for RADIO models.""" def __init__( self, args: Optional[dict] = None, version: Optional[str] = DEFAULT_VERSION, patch_size: Optional[int] = None, max_resolution: Optional[int] = None, preferred_resolution: Optional[Resolution] = None, adaptor_names: Union[str, List[str]] = None, adaptor_configs: Dict[str, Dict[str, int]] = None, vitdet_window_size: Optional[int] = None, feature_normalizer_config: Optional[dict] = None, inter_feature_normalizer_config: Optional[dict] = None, **kwargs, ): self.args = args for field in ["dtype", "amp_dtype"]: if self.args is not None and field in self.args: # Convert to a 
string in order to make it serializable. # For example for torch.float32 we will store "float32", # for "bfloat16" we will store "bfloat16". self.args[field] = str(args[field]).split(".")[-1] self.version = version resource = RESOURCE_MAP[version] self.patch_size = patch_size or resource.patch_size self.max_resolution = max_resolution or resource.max_resolution self.preferred_resolution = ( preferred_resolution or resource.preferred_resolution ) self.adaptor_names = adaptor_names self.adaptor_configs = adaptor_configs self.vitdet_window_size = vitdet_window_size self.feature_normalizer_config = feature_normalizer_config self.inter_feature_normalizer_config = inter_feature_normalizer_config super().__init__(**kwargs) class RADIOModel(PreTrainedModel): """Pretrained Hugging Face model for RADIO. This class inherits from PreTrainedModel, which provides HuggingFace's functionality for loading and saving models. """ config_class = RADIOConfig def __init__(self, config: RADIOConfig): super().__init__(config) RADIOArgs = namedtuple("RADIOArgs", config.args.keys()) args = RADIOArgs(**config.args) self.config = config model = create_model_from_args(args) input_conditioner: InputConditioner = get_default_conditioner() dtype = getattr(args, "dtype", torch.float32) if isinstance(dtype, str): # Convert the dtype's string representation back to a dtype. 
dtype = getattr(torch, dtype) model.to(dtype=dtype) input_conditioner.dtype = dtype summary_idxs = torch.tensor( [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)], dtype=torch.int64, ) adaptor_configs = config.adaptor_configs adaptor_names = config.adaptor_names or [] adaptors = dict() for adaptor_name in adaptor_names: mlp_config = adaptor_configs[adaptor_name] adaptor = GenericAdaptor(args, None, None, mlp_config) adaptor.head_idx = mlp_config["head_idx"] adaptors[adaptor_name] = adaptor feature_normalizer = None if config.feature_normalizer_config is not None: # Actual normalization values will be restored when loading checkpoint weights. feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"]) inter_feature_normalizer = None if config.inter_feature_normalizer_config is not None: inter_feature_normalizer = IntermediateFeatureNormalizer( config.inter_feature_normalizer_config["num_intermediates"], config.inter_feature_normalizer_config["embed_dim"], rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"], dtype=dtype) self.radio_model = RADIOModelBase( model, input_conditioner, summary_idxs=summary_idxs, patch_size=config.patch_size, max_resolution=config.max_resolution, window_size=config.vitdet_window_size, preferred_resolution=config.preferred_resolution, adaptors=adaptors, feature_normalizer=feature_normalizer, inter_feature_normalizer=inter_feature_normalizer, ) @property def adaptors(self) -> nn.ModuleDict: return self.radio_model.adaptors @property def model(self) -> VisionTransformer: return self.radio_model.model @property def input_conditioner(self) -> InputConditioner: return self.radio_model.input_conditioner @property def num_summary_tokens(self) -> int: return self.radio_model.num_summary_tokens @property def patch_size(self) -> int: return self.radio_model.patch_size @property def max_resolution(self) -> int: return self.radio_model.max_resolution @property def 
preferred_resolution(self) -> Resolution: return self.radio_model.preferred_resolution @property def window_size(self) -> int: return self.radio_model.window_size @property def min_resolution_step(self) -> int: return self.radio_model.min_resolution_step def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]: return self.radio_model.make_preprocessor_external() def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution: return self.radio_model.get_nearest_supported_resolution(height, width) def switch_to_deploy(self): return self.radio_model.switch_to_deploy() def forward(self, x: torch.Tensor): return self.radio_model.forward(x) ================================================ FILE: nit/models/nvidia_radio/radio/input_conditioner.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from typing import Union, Tuple import torch from torch import nn norm_t = Union[Tuple[float, float, float], torch.Tensor] class InputConditioner(nn.Module): def __init__(self, input_scale: float, norm_mean: norm_t, norm_std: norm_t, dtype: torch.dtype = None, ): super().__init__() self.dtype = dtype self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale) self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale) def forward(self, x: torch.Tensor): y = (x - self.norm_mean) / self.norm_std if self.dtype is not None: y = y.to(self.dtype) return y def get_default_conditioner(): from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD return InputConditioner( input_scale=1.0, norm_mean=OPENAI_CLIP_MEAN, norm_std=OPENAI_CLIP_STD, ) def _to_tensor(v: norm_t): return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1) ================================================ FILE: nit/models/nvidia_radio/radio/open_clip_adaptor.py ================================================ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
from argparse import Namespace import torch from torch import nn import torch.nn.functional as F from .adaptor_registry import adaptor_registry, dict_t, state_t from .adaptor_generic import GenericAdaptor class OpenCLIP_RADIO(GenericAdaptor): def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t): super().__init__(main_config, adaptor_config, state) import open_clip self.oc_model = open_clip.create_model_from_pretrained( model_name=adaptor_config['model'], pretrained=adaptor_config['pretrained'], return_transform=False, ) # Unload these parameters self.oc_model.visual = None self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model']) def encode_text(self, text, normalize: bool = False): return self.oc_model.encode_text(text, normalize=normalize) @adaptor_registry.register_adaptor("open_clip") def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t): return OpenCLIP_RADIO(main_config, adaptor_config, state) ================================================ FILE: nit/models/nvidia_radio/radio/radio_model.py ================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union import torch from torch import nn from timm.models import create_model, VisionTransformer from types import MethodType from .enable_cpe_support import enable_cpe from .input_conditioner import InputConditioner from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput from . 
import eradio_model from .enable_spectral_reparam import configure_spectral_reparam_from_args from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer from . import dual_hybrid_vit class Resolution(NamedTuple): height: int width: int class RADIOModel(nn.Module): def __init__( self, model: nn.Module, input_conditioner: InputConditioner, patch_size: int, max_resolution: int, preferred_resolution: Resolution, summary_idxs: Optional[torch.Tensor] = None, window_size: int = None, adaptors: Dict[str, AdaptorBase] = None, feature_normalizer: Optional[FeatureNormalizer] = None, inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None, ): super().__init__() self.model = model self.input_conditioner = input_conditioner if summary_idxs is not None: self.register_buffer('summary_idxs', summary_idxs) else: self.summary_idxs = None self._preferred_resolution = preferred_resolution self._patch_size = patch_size self._max_resolution = max_resolution self._window_size = window_size adaptors = adaptors or dict() self.adaptors = nn.ModuleDict(adaptors) if feature_normalizer is None: feature_normalizer = nn.Identity() self.feature_normalizer = feature_normalizer self.inter_feature_normalizer = inter_feature_normalizer @property def num_summary_tokens(self) -> int: if hasattr(self.model, 'num_summary_tokens'): return self.model.num_summary_tokens patch_gen = getattr(self.model, "patch_generator", None) if patch_gen is not None: return patch_gen.num_skip elif getattr(self.model, 'global_pool', None) == 'avg': return 0 return 1 @property def num_cls_tokens(self) -> int: if hasattr(self.model, 'num_cls_tokens'): return self.model.num_cls_tokens patch_gen = getattr(self.model, 'patch_generator', None) if patch_gen is not None: return patch_gen.num_cls_tokens elif getattr(self.model, 'global_pool', None) == 'avg': return 0 return 1 @property def patch_size(self) -> int: if self._patch_size is not None: return self._patch_size if hasattr(self.model, 
"patch_size"): return self.model.patch_size patch_gen = getattr(self.model, "patch_generator", None) if patch_gen is not None: return patch_gen.patch_size return None @property def max_resolution(self) -> int: return self._max_resolution @property def preferred_resolution(self) -> Resolution: return self._preferred_resolution @property def window_size(self) -> int: return self._window_size @property def min_resolution_step(self) -> int: res = self.patch_size if self.window_size is not None: res *= self.window_size return res @property def blocks(self) -> Iterable[nn.Module]: blocks = getattr(self.model, 'blocks', None) if blocks is not None: return blocks return None @property def embed_dim(self) -> int: return self.model.embed_dim def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]: ret = self.input_conditioner self.input_conditioner = nn.Identity() return ret def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution: height = int(round(height / self.min_resolution_step) * self.min_resolution_step) width = int(round(width / self.min_resolution_step) * self.min_resolution_step) height = max(height, self.min_resolution_step) width = max(width, self.min_resolution_step) return Resolution(height=height, width=width) def switch_to_deploy(self): fn = getattr(self.model, 'switch_to_deploy', None) if fn is not None: fn() def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ''' Forward process for model. Args: x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`, otherwise `x` is expected to be mean centered with unit standard deviation. feature_format: ['NLC', 'NCHW'] - The output format for the features. 
''' res_step = self.min_resolution_step if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0): raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. ' '`self.get_nearest_supported_resolution(, ) is provided as a convenience API. ' f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}') x = self.input_conditioner(x) y = self.model.forward_features(x) ret = self._extract_final(x, y, feature_fmt=feature_fmt) return ret def forward_pack(self, x: List[torch.Tensor], feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ''' Forward process for model. Args: x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`, otherwise `x` is expected to be mean centered with unit standard deviation. feature_format: ['NLC', 'NCHW'] - The output format for the features. ''' res_step = self.min_resolution_step for _x in x: if res_step is not None and (_x.shape[-2] % res_step != 0 or _x.shape[-1] % res_step != 0): raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. ' '`self.get_nearest_supported_resolution(, ) is provided as a convenience API. 
' f'Input: {_x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*_x.shape[-2:])}') x = [self.input_conditioner(_x) for _x in x] y, cu_seqlens = self.model.forward_features(x) all_summary, spatial_features = [], [] num_cls_tokens = self.model.patch_generator.num_cls_tokens num_skip = self.model.patch_generator.num_skip for i in range(len(cu_seqlens)-1): summary = y[cu_seqlens[i]: cu_seqlens[i+1]][: num_cls_tokens] all_feat = y[cu_seqlens[i]: cu_seqlens[i+1]][num_skip :] all_summary.append(summary) spatial_features.append(all_feat) all_summary = torch.cat(all_summary) spatial_features = torch.cat(spatial_features) return all_summary, spatial_features def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'): if isinstance(self.model, VisionTransformer): patch_gen = getattr(self.model, "patch_generator", None) if patch_gen is not None: all_summary = y[:, : patch_gen.num_cls_tokens] if self.summary_idxs is not None: bb_summary = all_summary[:, self.summary_idxs] else: bb_summary = all_summary all_feat = y[:, patch_gen.num_skip :] elif self.model.global_pool == "avg": all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1) bb_summary = all_summary all_feat = y else: all_summary = y[:, 0] bb_summary = all_summary all_feat = y[:, 1:] elif isinstance(self.model, eradio_model.ERADIO): _, f = y all_feat = f.flatten(2).transpose(1, 2) all_summary = all_feat.mean(dim=1) bb_summary = all_summary elif isinstance(y, (list, tuple)): all_summary, all_feat = y bb_summary = all_summary else: all_summary = y[:, :self.num_cls_tokens] if self.summary_idxs is not None and all_summary.shape[1] > 1: if all_summary.shape[1] == 1: # Create dummy duplicates all_summary = all_summary.expand(-1, 128, -1) bb_summary = all_summary[:, self.summary_idxs] else: bb_summary = all_summary all_feat = y[:, self.num_summary_tokens:] all_feat = self.feature_normalizer(all_feat) if feature_fmt == 'NCHW': fmt_feat = (all_feat.reshape(all_feat.shape[0], 
x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2]) .permute(0, 3, 1, 2) ) elif feature_fmt == 'NLC': fmt_feat = all_feat else: raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]') ret = RadioOutput(bb_summary.flatten(1), fmt_feat) if self.adaptors: ret = dict(backbone=ret) for name, adaptor in self.adaptors.items(): if all_summary.ndim == 3: if all_summary.shape[1] == 1: summary = all_summary[:, 0] else: summary = all_summary[:, adaptor.head_idx] else: summary = all_summary ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size) v = adaptor(ada_input).to(torch.float32) ret[name] = v return ret def forward_intermediates( self, x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, output_fmt: str = 'NCHW', intermediates_only: bool = False, aggregation: Optional[str] = "sparse", norm_alpha_scheme: Optional[str] = "post-alpha", ) -> List[RadioOutput]: """ Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC intermediates_only: Only return intermediate features aggregation: intermediate layer aggregation method (sparse or dense). Dense accumulation is done by averaging the features in each group. norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none") Only affects dense aggregation Returns: List of RadioOutput objects. 
""" x = self.input_conditioner(x) intermediates = self.model.forward_intermediates( x, indices=indices, return_prefix_tokens=return_prefix_tokens, norm=norm, stop_early=stop_early, output_fmt=output_fmt, intermediates_only=intermediates_only, aggregation=aggregation, inter_feature_normalizer=self.inter_feature_normalizer, norm_alpha_scheme=norm_alpha_scheme, ) if not intermediates_only: final, intermediates = intermediates def prepare_summary(summ: Optional[torch.Tensor]): if summ is None: return summ if self.summary_idxs is not None and summ.shape[1] > 1: summ = summ[:, self.summary_idxs] return summ.flatten(1) if return_prefix_tokens: radio_outputs = [ RadioOutput(prepare_summary(summary), features) for summary, features in intermediates ] else: radio_outputs = intermediates if intermediates_only: return radio_outputs else: final = self._extract_final(x, final, feature_fmt=output_fmt) return final, radio_outputs def create_model_from_args(args) -> nn.Module: in_chans = 3 if args.in_chans is not None: in_chans = args.in_chans elif args.input_size is not None: in_chans = args.input_size[0] # Skip weight initialization unless it's explicitly requested. 
weight_init = args.model_kwargs.pop("weight_init", "skip") model = create_model( args.model, pretrained=args.pretrained, in_chans=in_chans, num_classes=args.num_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, drop_block_rate=args.drop_block, global_pool=args.gp, bn_momentum=args.bn_momentum, bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, weight_init=weight_init, **args.model_kwargs, ) if hasattr(model, 'norm') and not getattr(args, 'model_norm', False): model.norm = nn.Identity() model.head = nn.Identity() if args.cpe_max_size is not None: uq_teachers = set(t['name'] for t in args.teachers) enable_cpe( model, args.cpe_max_size, num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1, register_multiple=getattr(args, 'register_multiple', None), num_registers=getattr(args, 'cpe_num_registers', None), support_packing=args.support_packing, ) return model ================================================ FILE: nit/models/nvidia_radio/radio/vision_transformer_xpos.py ================================================ import math from typing import Final, List, Optional, Tuple, Union from einops import rearrange from timm.models import register_model import torch from torch import Type, nn from torch.nn import functional as F from torch.nn.init import xavier_normal_, xavier_uniform_, zeros_ from .forward_intermediates import forward_intermediates def _get_init_scale(num_encoder_layers: int, num_decoder_layers: int, is_encoder: bool): if num_encoder_layers > 0 and num_decoder_layers == 0: return math.sqrt(math.log(2 * num_encoder_layers)) if num_decoder_layers > 0 and num_encoder_layers == 0: return math.sqrt(math.log(2 * num_decoder_layers)) if is_encoder: # Both encoders and decoders return math.sqrt( 0.33 * math.log(3 * num_decoder_layers) * math.log(2 * num_encoder_layers) ) return math.sqrt(math.log(3 * num_decoder_layers)) # [1,2] [1,1,2,2] # [3,4] -> [3,3,4,4] # [5,6] [5,5,6,6] def 
duplicate_interleave(m): return m.view(-1, 1).repeat(1, 2).view(m.shape[0], -1) # 0,1,2,3,4,5,6,7 -> -1,0,-3,2,-5,4,-7,6 def rotate_every_two(x): x1 = x[:, :, ::2] x2 = x[:, :, 1::2] x = torch.stack((-x2, x1), dim=-1) return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ class XPosEmbedding2D(torch.nn.Module): """Implementation of xPos based on RotaryEmbedding from GPT-NeoX. This implementation is designed to operate on queries and keys that are compatible with [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format). """ def __init__( self, head_dim: int, base=50000, scale_base=512 ): super().__init__() half_dim = head_dim // 2 self.half_dim = half_dim inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.head_dim = head_dim self.token_shape_cached = None self.batch_size_cached = None self.cos_cached: torch.Tensor | None = None self.sin_cached: torch.Tensor | None = None self.scale_cached: torch.Tensor | None = None self.scale_base = scale_base self.register_buffer("scale", (torch.arange(0, half_dim, 2) + 0.4 * half_dim) / (1.4 * half_dim)) def cos_sin( self, token_shape: Tuple[int, int], device="cuda", dtype=torch.bfloat16, ) -> torch.Tensor: if token_shape != self.token_shape_cached: self.token_shape_cached = token_shape y = torch.arange(token_shape[0], device=device, dtype=self.inv_freq.dtype) x = torch.arange(token_shape[1], device=device, dtype=self.inv_freq.dtype) x, y = torch.meshgrid(x, y, indexing='xy') y_freqs = torch.einsum("i,j->ij", y.flatten(), self.inv_freq) x_freqs = torch.einsum("i,j->ij", x.flatten(), self.inv_freq) y_scales = self.scale ** y.flatten().div(self.scale_base)[:, None] x_scales = self.scale ** x.flatten().div(self.scale_base)[:, None] freqs = torch.cat([y_freqs, x_freqs], dim=-1) emb = torch.repeat_interleave(freqs, repeats=2, dim=-1) scales = torch.cat([y_scales, x_scales], dim=-1) scales = 
torch.repeat_interleave(scales, repeats=2, dim=-1) if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() self.cos_cached = emb.cos()[None, :, :] self.sin_cached = emb.sin()[None, :, :] self.scale_cached = scales[None, :, :] self.cos_cached = self.cos_cached.type(dtype) self.sin_cached = self.sin_cached.type(dtype) self.scale_cached = self.scale_cached.type(dtype) return self.cos_cached, self.sin_cached, self.scale_cached def forward(self, q: torch.Tensor, k: torch.Tensor, token_shape: Tuple[int, int]): batch, seq_len, head_dim = q.shape cos, sin, scale = self.cos_sin(token_shape, q.device, q.dtype) # scale = self.scale**torch.arange(seq_len).to(self.scale).div(self.scale_base)[:, None] # scale = torch.repeat_interleave(scale, 2, dim=-1).to(q.device) # scale = torch.cat([scale, scale], dim=-1) # scale = 1 return ( (q * cos * scale) + (rotate_every_two(q) * sin * scale), (k * cos * (1 / scale)) + (rotate_every_two(k) * sin * (1 / scale)), ) class MagnetoAttention(nn.Module): def __init__(self, d_model: int, n_head: int, pos_emb: XPosEmbedding2D): super().__init__() self.num_heads = n_head self.head_dim = d_model // n_head self.scale = self.head_dim ** -0.5 self.qkv = nn.Linear(d_model, d_model * 3, bias=False) self.proj = nn.Linear(d_model, d_model) self.pos_emb = pos_emb self.norm0 = nn.LayerNorm(d_model) self.norm1 = nn.LayerNorm(d_model) def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor: B, N, C = x.shape x = self.norm0(x) qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3) q, k, v = qkv.unbind(0) q_pref = q[:, :num_prefix_tokens] q_patch = q[:, num_prefix_tokens:] k_pref = k[:, :num_prefix_tokens] k_patch = k[:, num_prefix_tokens:] q_patch, k_patch = self.pos_emb(q_patch, k_patch, patch_shape) q = torch.cat([q_pref, q_patch], dim=1) k = torch.cat([k_pref, k_patch], dim=1) def head_reshape(t: torch.Tensor): return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) q = 
head_reshape(q) k = head_reshape(k) v = head_reshape(v) x = F.scaled_dot_product_attention(q, k, v) x = x.transpose(1, 2).reshape(B, N, C) x = self.norm1(x) x = self.proj(x) return x def _reset_parameters(self): xavier_uniform_(self.qkv.weight) if self.qkv.bias is not None: zeros_(self.qkv.bias) xavier_normal_(self.proj.weight) zeros_(self.proj.bias) class MagnetoTransformerEncoderLayer(nn.Module): def __init__(self, d_model: int, nhead: int, pos_emb: XPosEmbedding2D, num_encoder_layers: int, num_decoder_layers: int = 0, dim_mhsa: int = 0, dim_feedforward: int = 2048, layer_norm_eps: float = 1e-5, batch_first: bool = True): super().__init__() if dim_mhsa == 0: dim_mhsa = d_model self._num_encoder_layers = num_encoder_layers self._num_decoder_layers = num_decoder_layers self.attn = MagnetoAttention(d_model, nhead, pos_emb) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.linear2 = nn.Linear(d_model, dim_feedforward) self.norm3 = nn.LayerNorm(dim_feedforward, eps=layer_norm_eps) self.linear3 = nn.Linear(dim_feedforward, d_model) def initialize(self): gamma = _get_init_scale(self._num_encoder_layers, self._num_decoder_layers, is_encoder=True) # Magneto Initialization for mod in self.children(): if isinstance(mod, nn.Linear): xavier_normal_(mod.weight.data, gamma) elif isinstance(mod, MagnetoAttention): mod._reset_parameters() def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor: x = x + self._sa_block(x, num_prefix_tokens, patch_shape) x = x + self._ff_block(x) return x def _sa_block(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor: x = self.attn(x, num_prefix_tokens, patch_shape) return x def _ff_block(self, x: torch.Tensor) -> torch.Tensor: x = self.norm2(x) x = self.linear2(x) x = F.gelu(x) x = self.norm3(x) x = self.linear3(x) return x class VisionTransformer(nn.Module): """ Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers 
for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ dynamic_img_size: Final[bool] def __init__( self, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4., num_cls_tokens: int = 1, num_reg_tokens: int = 0, ) -> None: """ Args: patch_size: Patch size. in_chans: Number of image input channels. embed_dim: Transformer embedding dimension. depth: Depth of transformer. num_heads: Number of attention heads. mlp_ratio: Ratio of mlp hidden dim to embedding dim. num_cls_tokens: Number of cls tokens num_reg_tokens: Number of register tokens. block_fn: Transformer block layer. """ super().__init__() self.patch_size = patch_size self.embed_dim = embed_dim self.num_cls_tokens = num_cls_tokens self.num_reg_tokens = num_reg_tokens self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.prefix_buffer = nn.Parameter(torch.randn(1, self.num_prefix_tokens, embed_dim) * .02) pos_emb = XPosEmbedding2D(embed_dim) self.blocks = nn.ModuleList([ MagnetoTransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, num_encoder_layers=depth, num_decoder_layers=0, dim_feedforward=int(embed_dim * mlp_ratio), pos_emb=pos_emb, ) for _ in range(depth) ]) for block in self.blocks: block.initialize() @property def num_prefix_tokens(self): return self.num_cls_tokens + self.num_reg_tokens @property def num_summary_tokens(self): return self.num_prefix_tokens def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: x, patch_shape = self._patchify(x) for block in self.blocks: x = block(x, self.num_prefix_tokens, patch_shape) summary = x[:, :self.num_cls_tokens] features = x[:, self.num_prefix_tokens:] return summary, features def forward_intermediates(self, x: torch.Tensor, norm: bool = False, **kwargs): patch_shape = tuple(d // self.patch_size for d in x.shape[-2:]) def patch_extractor(x: torch.Tensor): x, _ = self._patchify(x) 
return x return forward_intermediates( self, patch_extractor=patch_extractor, num_summary_tokens=self.num_prefix_tokens, num_cls_tokens=self.num_cls_tokens, norm=lambda y: y, x=x, block_kwargs=dict(num_prefix_tokens=self.num_prefix_tokens, patch_shape=patch_shape), **kwargs, ) def _patchify(self, x: torch.Tensor): x = self.patch_embed(x) patch_shape = x.shape[-2:] x = rearrange(x, 'b c h w -> b (h w) c') prefix = self.prefix_buffer.expand(x.shape[0], -1, -1) x = torch.cat([prefix, x], dim=1) return x, patch_shape @register_model def vit_base_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer: return VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens) @register_model def vit_large_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer: return VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens) @register_model def vit_huge_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer: return VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens) @register_model def vit_giant_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer: return VisionTransformer(patch_size=16, embed_dim=1408, depth=40, num_heads=16, num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens) @register_model def vit_bigG_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer: return VisionTransformer(patch_size=16, embed_dim=1664, depth=48, num_heads=16, num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens) ================================================ FILE: nit/models/nvidia_radio/radio/vit_patch_generator.py 
================================================ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import math from typing import Union, Tuple, Optional import torch import torch.nn.functional as F from torch import nn from einops import rearrange from .cls_token import ClsToken input_dim_t = Union[int, Tuple[int, int]] try: # raise ImportError() from indirect_grid_sample import indirect_grid_sample except ImportError: indirect_grid_sample = None class ViTPatchGenerator(nn.Module): def __init__(self, patch_size: int, embed_dim: int, input_dims: input_dim_t, abs_pos: bool = True, normalize_patches: bool = False, cls_token: bool = False, max_input_dims: Optional[input_dim_t] = None, pos_dropout: float = 0.0, return_pos_enc: bool = False, num_cls_tokens: int = 1, register_multiple: Optional[int] = None, num_registers: Optional[int] = None, patch_bias: bool = False, device=None, dtype=None, ): super().__init__() if isinstance(input_dims, int): input_dims = (input_dims, input_dims) if max_input_dims is None: max_input_dims = input_dims if isinstance(max_input_dims, int): max_input_dims = (max_input_dims, max_input_dims) max_input_dims = tuple( int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims ) self.cpe_mode = max_input_dims != input_dims self.pos_dropout = pos_dropout self.return_pos_enc = return_pos_enc factory = dict(device=device, dtype=dtype) self.patch_size = patch_size self.abs_pos = abs_pos self.embed_dim = embed_dim self.num_rows = max_input_dims[0] // patch_size self.num_cols = max_input_dims[1] // patch_size self.input_dims = tuple(d // patch_size for d in input_dims) 
self.num_patches = self.num_rows * self.num_cols self.max_input_dims = max_input_dims self.im_to_patches = Im2Patches(patch_size) self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias, **factory) if abs_pos: scale = embed_dim ** -0.5 self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale) self.cls_token = ClsToken( embed_dim, num_tokens=num_cls_tokens, enabled=cls_token, register_multiple=register_multiple, num_registers=num_registers, ) self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: patches = self.embed_patches(x) patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:]) patches = self.cls_token(patches) patches = self.patch_normalizer(patches) if self.return_pos_enc: return patches, pos_enc return patches @property def apply_cls_token(self): return self.cls_token.enabled @property def num_cls_tokens(self): return self.cls_token.num_tokens @property def num_cls_patches(self): return self.cls_token.num_patches @property def num_registers(self): return self.cls_token.num_registers @property def num_skip(self): return self.num_cls_tokens + self.num_registers def no_weight_decay(self): return [ 'pos_embed', ] def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter): if src_embed.shape != targ_embed.shape: src_size = int(math.sqrt(src_embed.shape[1])) assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding' src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size) src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False) src_embed = rearrange(src_embed, 'b c h w -> b (h w) c') targ_embed.data.copy_(src_embed) def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor): if src_proj_weight.shape != targ_proj_weight.shape: src_patch_size = 
int(math.sqrt(src_proj_weight.shape[1] // 3)) assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size' src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size) src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False) src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)') targ_proj_weight.data.copy_(src_proj_weight) def embed_patches(self, x: torch.Tensor) -> torch.Tensor: patches = self.im_to_patches(x) patches = self.embedder(patches) return patches def apply_pos_enc(self, patches: torch.Tensor, patch_idxs: Optional[torch.Tensor] = None, input_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: if not self.abs_pos: return patches pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size) if self.training and self.pos_dropout > 0: keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout pos_enc_drop = torch.where(keeps, pos_enc, 0) else: pos_enc_drop = pos_enc return patches + pos_enc_drop, pos_enc def get_pos_enc(self, batch_size: int, patch_idxs: Optional[torch.Tensor] = None, input_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: if input_size is None: input_dims = self.input_dims else: input_dims = tuple(d // self.patch_size for d in input_size) pos_embed = self._get_pos_embeddings(batch_size, input_dims) if patch_idxs is None: return pos_embed exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]) pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs) return pos_embed def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]): if (self.num_rows, self.num_cols) == input_dims: return self.pos_embed pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2) def 
window_select(pos_embed): if input_dims[0] < pos_embed.shape[-2]: pos_embed = pos_embed[..., :input_dims[0], :] if input_dims[1] < pos_embed.shape[-1]: pos_embed = pos_embed[..., :, :input_dims[1]] return pos_embed if self.cpe_mode: if self.training: min_scale = math.sqrt(0.1) scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale aspect_min = math.log(3 / 4) aspect_max = -aspect_min aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min) scale_x = scale * aspect scale_y = scale * (1 / aspect) scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1) pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy) lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1) lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1]) lin_xy = torch.stack([lin_x, lin_y], dim=-1) grid_xy = lin_xy * scale_xy + pos_xy # Convert to [-1, 1] range grid_xy.mul_(2).sub_(1) pos_embed = F.grid_sample( pos_embed.float().expand(batch_size, -1, -1, -1), grid=grid_xy, mode='bilinear', padding_mode='zeros', align_corners=True, ).to(pos_embed.dtype) else: # i_rows, i_cols = input_dims # p_rows, p_cols = pos_embed.shape[2:] # if i_rows <= p_rows and i_cols <= p_cols: # left = (p_cols - i_cols) // 2 # top = (p_rows - i_rows) // 2 # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols] # else: max_dim = max(input_dims) pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype) pos_embed = window_select(pos_embed) else: pos_embed = window_select(pos_embed) if pos_embed.shape[-2:] != input_dims: pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype) pos_embed = pos_embed.flatten(2).permute(0, 2, 1) 
return pos_embed class Im2Patches(nn.Module): def __init__(self, patch_size: int): super().__init__() self.patch_size = patch_size def forward(self, x: torch.Tensor) -> torch.Tensor: if self.patch_size == 1: patches = x.flatten(2) patches = patches.permute(0, 2, 1) return patches py = x.shape[-2] // self.patch_size px = x.shape[-1] // self.patch_size patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)', py=py, yy=self.patch_size, px=px, xx=self.patch_size, ) return patches class ViTPatchLinear(nn.Linear): def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory): super().__init__( 3 * (patch_size ** 2), embed_dim, bias=bias, **factory ) self.patch_size = patch_size ================================================ FILE: nit/models/nvidia_radio/radio/vitdet.py ================================================ from collections import defaultdict from contextlib import contextmanager from logging import getLogger import math import sys from typing import List, Union, Iterable import numpy as np import torch from torch import nn from timm.models import VisionTransformer from einops import rearrange from .extra_models import DinoWrapper DEFAULT_NUM_WINDOWED = 5 DEFAULT_NUM_GLOBAL = 4 class VitDetArgs: def __init__(self, window_size: int, num_summary_tokens: int, num_windowed: int = None, num_global: int = None, ): self.window_size = window_size self.num_summary_tokens = num_summary_tokens self.num_windowed = num_windowed self.num_global = num_global def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs): if isinstance(model, VisionTransformer): patch_embed = getattr(model, 'patch_generator', model.patch_embed) return ViTDetHook(patch_embed, model.blocks, args) elif isinstance(model, DinoWrapper): inner = model.inner patch_embed = getattr(inner, 'patch_generator', inner.patch_embed) return ViTDetHook(patch_embed, inner.blocks, args) else: print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr) 
class ViTDetHook: def __init__(self, embedder: nn.Module, blocks: nn.Sequential, args: VitDetArgs, ): self.blocks = blocks self.num_summary_tokens = args.num_summary_tokens self.window_size = args.window_size self._input_resolution = None self._num_windows = None self._cls_patch = None self._order_cache = dict() embedder.register_forward_pre_hook(self._enter_model) # This will decide if we window-fy the patches # and enable vit-det for this iteration, and if so, # rearrange the patches for efficient mode switching blocks.register_forward_pre_hook(self._enter_blocks) is_global = True if args.num_windowed is not None: period = args.num_windowed + 1 else: num_global = args.num_global or DEFAULT_NUM_GLOBAL period = max(len(blocks) // num_global, 1) for i, layer in enumerate(blocks[:-1]): ctr = i % period if ctr == 0: layer.register_forward_pre_hook(self._to_windows) is_global = False elif ctr == period - 1: layer.register_forward_pre_hook(self._to_global) is_global = True # Always ensure the final layer is a global layer if not is_global: blocks[-1].register_forward_pre_hook(self._to_global) blocks.register_forward_hook(self._exit_model) def _enter_model(self, _, input: List[torch.Tensor]): self._input_resolution = input[0].shape[-2:] def _enter_blocks(self, _, input: List[torch.Tensor]): # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr) patches = input[0] patches = self._rearrange_patches(patches) return (patches,) + input[1:] def _to_windows(self, _, input: List[torch.Tensor]): patches = input[0] if self.num_summary_tokens: self._cls_patch = patches[:, :self.num_summary_tokens] patches = patches[:, self.num_summary_tokens:] patches = rearrange( patches, 'b (p t) c -> (b p) t c', p=self._num_windows, t=self.window_size ** 2, ) return (patches,) + input[1:] def _to_global(self, _, input: List[torch.Tensor]): patches = input[0] patches = rearrange( patches, '(b p) t c -> b (p t) c', p=self._num_windows, t=self.window_size ** 2, 
b=patches.shape[0] // self._num_windows, ) if self.num_summary_tokens: patches = torch.cat([ self._cls_patch, patches, ], dim=1) return (patches,) + input[1:] def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor): # Return patches to their original order patch_order = self._order_cache[self._input_resolution][0] patch_order = patch_order.reshape(1, -1, 1).expand_as(patches) ret_patches = torch.empty_like(patches) ret_patches = torch.scatter( ret_patches, dim=1, index=patch_order, src=patches, ) return ret_patches def _rearrange_patches(self, patches: torch.Tensor): # We rearrange the patches so that we can efficiently # switch between windowed and global mode by just # reshaping the tensor patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None)) if patch_order is None: num_feat_patches = patches.shape[1] - self.num_summary_tokens num_pixels = self._input_resolution[0] * self._input_resolution[1] patch_size = int(round(math.sqrt(num_pixels / num_feat_patches))) rows = self._input_resolution[-2] // patch_size cols = self._input_resolution[-1] // patch_size w_rows = rows // self.window_size w_cols = cols // self.window_size patch_order = torch.arange(0, num_feat_patches, device=patches.device) patch_order = rearrange( patch_order, '(wy py wx px) -> (wy wx py px)', wy=w_rows, wx=w_cols, py=self.window_size, px=self.window_size, ) if self.num_summary_tokens: patch_order = torch.cat([ torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device), patch_order + self.num_summary_tokens, ]) self._num_windows = w_rows * w_cols self._order_cache[self._input_resolution] = ( patch_order, self._num_windows, ) patch_order = patch_order.reshape(1, -1, 1).expand_as(patches) patches = torch.gather(patches, dim=1, index=patch_order) return patches ================================================ FILE: nit/models/utils/convs.py ================================================ import torch import 
torch.nn as nn import torch.nn.functional as F from nit.models.efficientvit.models.nn.ops import ConvLayer from nit.models.efficientvit.models.nn.act import build_act from nit.models.efficientvit.models.utils import val2tuple def create_conv_1(conv_type, in_channels, out_channels, norm, act_func, groups=1): ''' conv_type: dwconv_3x3_1, dsconv_3x3_1, dgconv_3x3_1 ''' if conv_type == None or conv_type == "": return nn.Identity() splited_conv_type = conv_type.split('_') conv_type = splited_conv_type[0] kernel_size = int(splited_conv_type[1].split('x')[0]) stride = int(splited_conv_type[2]) if conv_type == 'dwconv': return DWConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func) elif conv_type == 'dsconv': return DSConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func) elif conv_type == 'dgconv': return DGConv(in_channels, out_channels, kernel_size, stride, groups, norm=norm, act_func=act_func) else: return nn.Identity() def create_conv_2(conv_type, in_channels, out_channels, mid_channels): ''' conv_type: mbconv_3x3_1, fusedmbconv_3x3_1, glumbconv_3x3_1 ''' if conv_type == None or conv_type == "": return nn.Identity() splited_conv_type = conv_type.split('_') conv_type = splited_conv_type[0] kernel_size = int(splited_conv_type[1].split('x')[0]) stride = int(splited_conv_type[2]) if conv_type == 'mbconv': return MBConv(in_channels, out_channels, kernel_size, stride, mid_channels) elif conv_type == 'fusedmbconv': return FusedMBConv(in_channels, out_channels, kernel_size, stride, mid_channels) elif conv_type == 'glumbconv': return GLUMBConv(in_channels, out_channels, kernel_size, stride, mid_channels) else: return nn.Identity() class DWConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, use_bias=True, norm="bn2d", act_func="relu6", ): super(DWConv, self).__init__() self.depth_conv = ConvLayer( in_channels, out_channels, kernel_size, stride, groups=in_channels, norm=norm, 
act_func=act_func, use_bias=use_bias, ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.depth_conv(x) return x class DSConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, use_bias=(True, True), norm=("bn2d", "bn2d"), act_func=("relu6", None), ): super(DSConv, self).__init__() use_bias = val2tuple(use_bias, 2) norm = val2tuple(norm, 2) act_func = val2tuple(act_func, 2) self.depth_conv = ConvLayer( in_channels, in_channels, kernel_size, stride, groups=in_channels, norm=norm[0], act_func=act_func[0], use_bias=use_bias[0], ) self.point_conv = ConvLayer( in_channels, out_channels, 1, norm=norm[1], act_func=act_func[1], use_bias=use_bias[1], ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.depth_conv(x) x = self.point_conv(x) return x class DGConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, groups=16, use_bias=(True, True), norm=("bn2d", "bn2d"), act_func=("relu6", None), ): super(DGConv, self).__init__() use_bias = val2tuple(use_bias, 2) norm = val2tuple(norm, 2) act_func = val2tuple(act_func, 2) self.depth_conv = ConvLayer( in_channels, in_channels, kernel_size, stride, groups=in_channels, norm=norm[0], act_func=act_func[0], use_bias=use_bias[0], ) self.point_conv = ConvLayer( in_channels, out_channels, 1, groups=groups, norm=norm[1], act_func=act_func[1], use_bias=use_bias[1], ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.depth_conv(x) x = self.point_conv(x) return x class MBConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=6, use_bias=True, norm=("bn2d", "bn2d", "bn2d"), act_func=("relu6", "relu6", None), ): super(MBConv, self).__init__() use_bias = val2tuple(use_bias, 3) norm = val2tuple(norm, 3) act_func = val2tuple(act_func, 3) mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels self.inverted_conv = 
ConvLayer( in_channels, mid_channels, 1, stride=1, norm=norm[0], act_func=act_func[0], use_bias=use_bias[0], ) self.depth_conv = ConvLayer( mid_channels, mid_channels, kernel_size, stride=stride, groups=mid_channels, norm=norm[1], act_func=act_func[1], use_bias=use_bias[1], ) self.point_conv = ConvLayer( mid_channels, out_channels, 1, norm=norm[2], act_func=act_func[2], use_bias=use_bias[2], ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.inverted_conv(x) x = self.depth_conv(x) x = self.point_conv(x) return x class FusedMBConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=6, groups=1, use_bias=True, norm=("bn2d", "bn2d"), act_func=("relu6", None), ): super().__init__() use_bias = val2tuple(use_bias, 2) norm = val2tuple(norm, 2) act_func = val2tuple(act_func, 2) mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels self.spatial_conv = ConvLayer( in_channels, mid_channels, kernel_size, stride, groups=groups, use_bias=use_bias[0], norm=norm[0], act_func=act_func[0], ) self.point_conv = ConvLayer( mid_channels, out_channels, 1, use_bias=use_bias[1], norm=norm[1], act_func=act_func[1], ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.spatial_conv(x) x = self.point_conv(x) return x class GLUMBConv(nn.Module): def __init__( self, in_channels: int, out_channels: int, kernel_size=3, stride=1, mid_channels=None, expand_ratio=6, use_bias=True, norm=(None, None, "ln2d"), act_func=("silu", "silu", None), ): super().__init__() use_bias = val2tuple(use_bias, 3) norm = val2tuple(norm, 3) act_func = val2tuple(act_func, 3) mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels self.glu_act = build_act(act_func[1], inplace=False) self.inverted_conv = ConvLayer( in_channels, mid_channels * 2, 1, use_bias=use_bias[0], norm=norm[0], act_func=act_func[0], ) self.depth_conv = ConvLayer( mid_channels * 2, 
mid_channels * 2, kernel_size, stride=stride, groups=mid_channels * 2, use_bias=use_bias[1], norm=norm[1], act_func=None, ) self.point_conv = ConvLayer( mid_channels, out_channels, 1, use_bias=use_bias[2], norm=norm[2], act_func=act_func[2], ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.inverted_conv(x) x = self.depth_conv(x) x, gate = torch.chunk(x, 2, dim=1) gate = self.glu_act(gate) x = x * gate x = self.point_conv(x) return x ================================================ FILE: nit/models/utils/funcs.py ================================================ import torch from torch import Tensor from typing import List, Tuple from itertools import chain def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) def get_parameter_dtype(parameter: torch.nn.Module): try: params = tuple(parameter.parameters()) if len(params) > 0: return params[0].dtype buffers = tuple(parameter.buffers()) if len(buffers) > 0: return buffers[0].dtype except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) first_tuple = next(gen) return first_tuple[1].dtype ================================================ FILE: nit/models/utils/norms.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import math from functools import partial import torch import torch.nn as nn import triton import triton.language as tl import torch.nn.functional as F def create_norm(norm_type: str, dim: int, eps: float = 1e-6): """ Creates the specified normalization layer based on the norm_type. 
Args: norm_type (str): The type of normalization layer to create. Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm dim (int): The dimension of the normalization layer. eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6. Returns: The created normalization layer. Raises: NotImplementedError: If an unknown norm_type is provided. """ if norm_type == None or norm_type == "": return nn.Identity() norm_type = norm_type.lower() # Normalize to lowercase if norm_type == "layernorm": return nn.LayerNorm(dim, eps=eps, bias=False) elif norm_type == "np_layernorm": return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False) elif norm_type == "np_layernorm_32": return FP32_Layernorm(dim, eps=eps, elementwise_affine=False, bias=True) elif norm_type == "layernorm_32": return FP32_Layernorm(dim, eps=eps, bias=True) elif norm_type == "rmsnorm": return RMSNorm(dim, include_weight=True, eps=eps) elif norm_type == "np_rmsnorm": return RMSNorm(dim, include_weight=False, eps=1e-6) elif norm_type == "fused_rmsnorm": return FusedRMSNorm(dim, eps=1/65536) elif norm_type == "fused_rmsnorm_32": return FusedRMSNorm32(dim, eps=1e-6) elif norm_type == 'none': return nn.Identity() else: return nn.Identity() class FP32_Layernorm(nn.LayerNorm): def forward(self, inputs: torch.Tensor) -> torch.Tensor: origin_dtype = inputs.dtype if self.bias == None and self.weight == None: return F.layer_norm( input=inputs.float(), normalized_shape=self.normalized_shape, eps=self.eps ).to(origin_dtype) elif self.bias == None: return F.layer_norm( input=inputs.float(), normalized_shape=self.normalized_shape, weight=self.weight.float(), eps=self.eps ).to(origin_dtype) else: return F.layer_norm( input=inputs.float(), normalized_shape=self.normalized_shape, weight=self.weight.float(), bias=self.bias.float(), eps=self.eps ).to(origin_dtype) class FusedRMSNorm(nn.Module): """Fused RMS Norm, wraps a fused Triton Kernel""" def __init__( self, dim: int, 
eps: float = 1e-6, ): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) self.fused_rms_norm_fn = fused_rms_norm_fn def forward(self, x: torch.Tensor) -> torch.Tensor: """leverages Triton Fused RMS Norm kernel""" return self.fused_rms_norm_fn( x, self.weight, eps=self.eps, ) def reset_parameters(self): torch.nn.init.ones_(self.weight) # type: ignore class FusedRMSNorm32(nn.Module): """Fused RMS Norm, wraps a fused Triton Kernel""" def __init__( self, dim: int, eps: float = 1e-6, ): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) self.fused_rms_norm_fn = fused_rms_norm_fn def forward(self, x: torch.Tensor) -> torch.Tensor: """leverages Triton Fused RMS Norm kernel""" dtype = x.dtype return self.fused_rms_norm_fn( x.to(torch.float32), self.weight, eps=self.eps, ).to(dtype) def reset_parameters(self): torch.nn.init.ones_(self.weight) # type: ignore class RMSNorm(nn.Module): def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6, **block_kwargs): """ Initialize the RMSNorm normalization layer. Args: dim (int): The dimension of the input tensor. include_weight: bool: Whether include weight in the normalization eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. Attributes: eps (float): A small value added to the denominator for numerical stability. weight (nn.Parameter): Learnable scaling parameter. """ super().__init__() self.eps = eps if include_weight: self.weight = nn.Parameter(torch.ones(dim)) else: self.weight = None def _norm(self, x): """ Apply the RMSNorm normalization to the input tensor. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The normalized tensor. """ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ Forward pass through the RMSNorm layer. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after applying RMSNorm. 
""" output = self._norm(x.float()).type_as(x) if self.weight == None: return output else: return output * self.weight # FusedRMSNorm in Triton # Credit # Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py # Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html @triton.autotune( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], key=["N"], ) @triton.jit def _rms_norm_fwd_kernel( X, stride_x, Y, stride_y, W, Rstd, eps, M, # num rows N, # num cols block_N: tl.constexpr, ): row = tl.program_id(0) cols = tl.arange(0, block_N) # Load input data and weights mask = cols < N x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) # Compute mean and variance xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) # Store the reciprocal standard deviation tl.store(Rstd + row, rstd) # Normalize and apply linear transformation x_hat = x * rstd y = x_hat * w # Write output tl.store(Y + row * stride_y + cols, y, mask=mask) @triton.autotune( configs=[ triton.Config({}, num_warps=1), triton.Config({}, num_warps=2), triton.Config({}, num_warps=4), triton.Config({}, num_warps=8), triton.Config({}, num_warps=16), triton.Config({}, num_warps=32), ], key=["N"], ) @triton.jit def _rms_norm_bwd_kernel_sm( X, stride_x, W, DY, stride_dy, DX, stride_dx, Rstd, DW, eps, M, # num rows N, # num cols rows_per_program, block_N: tl.constexpr, ): row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program cols = tl.arange(0, block_N) mask = cols < N # Load weights w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32) # Accumulate gradients for weights dw = tl.zeros((block_N,), 
dtype=tl.float32) row_end = min(row_start + rows_per_program, M) for row in range(row_start, row_end): # Load input, output gradient, and reciprocal standard deviation x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) rstd = tl.load(Rstd + row) # Compute normalized input and gradients x_hat = x * rstd wdy = w * dy dw += dy * x_hat c1 = tl.sum(x_hat * wdy, axis=0) / N dx = (wdy - x_hat * c1) * rstd # Store input gradient tl.store(DX + row * stride_dx + cols, dx, mask=mask) # Store weight gradients tl.store(DW + row_block_id * N + cols, dw, mask=mask) class TritonFusedRMSNorm(torch.autograd.Function): @staticmethod def forward(ctx, x, weight, eps): x_shape_start = x.shape # Flatten input x = x.view(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if weight.stride(-1) != 1: weight = weight.contiguous() M, N = x.shape y = torch.empty_like(x) rstd = torch.empty((M,), dtype=torch.float32, device=x.device) max_size = 65536 // x.element_size() block_N = min(max_size, triton.next_power_of_2(N)) if N > block_N: raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (M,) _rms_norm_fwd_kernel[grid]( x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N, ) ctx.eps = eps ctx.save_for_backward(x, weight, rstd) ctx.x_shape_start = x_shape_start y = y.reshape(x_shape_start) return y @staticmethod def backward(ctx, dy): x, weight, rstd = ctx.saved_tensors eps = ctx.eps x_shape_start = ctx.x_shape_start # Flatten input and output gradients dy = dy.view(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() M, N = dy.shape dx = torch.empty_like(x) dw = torch.empty_like(weight) sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) max_size = 65536 // x.element_size() block_N = min(max_size, triton.next_power_of_2(N)) rows_per_sm = math.ceil(M / 
sm_count) if N > block_N: raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (sm_count,) _rms_norm_bwd_kernel_sm[grid]( x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, _dw, eps, M, N, rows_per_sm, block_N, ) dw = _dw.sum(0).to(weight.dtype) dx = dx.view(x_shape_start) return dx, dw, None # expose fusedRMSNorm as a function def fused_rms_norm_fn( x, weight, eps=1e-6, ): return TritonFusedRMSNorm.apply( x, weight, eps, ) ================================================ FILE: nit/models/utils/pos_embeds/flash_attn_rotary.py ================================================ # Copyright (c) 2023, Tri Dao. import math from typing import Optional, Tuple, Union import torch from einops import rearrange, repeat from flash_attn.ops.triton.rotary import apply_rotary def rotate_half(x, interleaved=False): if not interleaved: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) def apply_rotary_emb_torch(x, cos, sin, interleaved=False): """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") return torch.cat( [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], dim=-1, ) class ApplyRotaryEmb(torch.autograd.Function): @staticmethod def forward( ctx, x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): out = apply_rotary( x, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=interleaved, inplace=inplace, ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.inplace = inplace ctx.max_seqlen = max_seqlen return out if not inplace else x @staticmethod def backward(ctx, do): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: cos, sin, cu_seqlens = ctx.saved_tensors # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 
if not ctx.interleaved and not ctx.inplace: do = do.clone() dx = apply_rotary( do, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=ctx.inplace, conjugate=True, ) return dx, None, None, None, None, None, None, None def apply_rotary_emb( x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): """ Arguments: x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) cos, sin: (seqlen_rotary, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). inplace: if True, apply rotary embedding in-place. seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. cu_seqlens: (batch + 1,) or None max_seqlen: int Return: out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None else (total_seqlen, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding to the first rotary_dim of x. """ return ApplyRotaryEmb.apply( x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen ) # For backward compatibility apply_rotary_emb_func = apply_rotary_emb #TODO need check ,whlzy modified!!!! 
class ApplyRotaryEmbQKV_(torch.autograd.Function): @staticmethod def forward( ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): total, three, nheads, headdim = qkv.shape # (total, 3, nheads, headdim) assert three == 3 if cos_k is None and sin_k is None and qkv.is_contiguous(): # Call 1 kernel instead of 2 kernels # We need qkv to be contiguous so that when we reshape to combine (3, nheads) # dimensions, we get the same tensor # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") qk = qkv[:, :2].reshape(total, -1, headdim) apply_rotary( qk, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=interleaved, inplace=True ) else: cos_k = cos if cos_k is None else cos_k sin_k = sin if sin_k is None else sin_k q, k = qkv[:, 0], qkv[:, 1] apply_rotary( q, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=interleaved, inplace=True ) apply_rotary( k, cos_k, sin_k, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, interleaved=interleaved, inplace=True ) ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens) ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved ctx.max_seqlen = max_seqlen return qkv @staticmethod def backward(ctx, dqkv): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors else: cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors if cos_k is None and sin_k is None and dqkv.is_contiguous(): # Call 1 kernel instead of 2 kernels # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) # 
dimensions, we get the same tensor dqk = rearrange(dqkv[:, :, :2], "b t h d -> b (t h) d") # b for total apply_rotary( dqk, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=True, conjugate=True, ) else: cos_k = cos if cos_k is None else cos_k sin_k = sin if sin_k is None else sin_k dq, dk = dqkv[:, 0], dqkv[:, 1] apply_rotary( dq, cos, sin, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=True, conjugate=True ) apply_rotary( dk, cos_k, sin_k, seqlen_offsets=seqlen_offsets, cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen, interleaved=ctx.interleaved, inplace=True, conjugate=True, ) return dqkv, None, None, None, None, None, None, None, None #TODO need check ,whlzy modified!!!! def apply_rotary_emb_qkv_( qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, ): """ Arguments: qkv: (batch_size, seqlen, 3, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) cos_k, sin_k: (seqlen, rotary_dim / 2), optional interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. Most commonly used in inference when we have KV cache. Return: qkv: (batch_size, seqlen, 3, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of Q and K. 
""" return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen) class ApplyRotaryEmbKV_(torch.autograd.Function): @staticmethod def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): batch, seqlen, two, nheads, headdim = kv.shape assert two == 2 k = kv[:, :, 0] apply_rotary( k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True ) if isinstance(seqlen_offsets, int): ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward ctx.seqlen_offsets = seqlen_offsets else: ctx.save_for_backward(cos, sin, seqlen_offsets) ctx.seqlen_offsets = None ctx.interleaved = interleaved return kv @staticmethod def backward(ctx, dkv): seqlen_offsets = ctx.seqlen_offsets if seqlen_offsets is None: cos, sin, seqlen_offsets = ctx.saved_tensors else: cos, sin = ctx.saved_tensors apply_rotary( dkv[:, :, 0], cos, sin, seqlen_offsets=seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True, ) return dkv, None, None, None, None apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply def apply_rotary_emb_kv_( kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0, ): """ Arguments: kv: (batch_size, seqlen, 2, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. Most commonly used in inference when we have KV cache. Return: kv: (batch_size, seqlen, 2, nheads, headdim) rotary_dim must be <= headdim Apply rotary embedding *inplace* to the first rotary_dim of K. """ return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets) class RotaryEmbedding(torch.nn.Module): """ The rotary position embeddings from RoFormer_ (Su et. al). 
A crucial insight from the method is that the query and keys are transformed by rotation matrices which depend on the relative positions. Other implementations are available in the Rotary Transformer repo_ and in GPT-NeoX_, GPT-NeoX was an inspiration .. _RoFormer: https://arxiv.org/abs/2104.09864 .. _repo: https://github.com/ZhuiyiTechnology/roformer .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py """ def __init__( self, dim: int, base=10000.0, interleaved=False, scale_base=None, pos_idx_in_fp32=True, device=None, ): """ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style). pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. This option was added because previously (before 2023-07-02), when we construct the position indices, we use the dtype of self.inv_freq. In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then self.inv_freq would be bf16, and the position indices are also in bf16. Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the embeddings for some positions will coincide. To maintain compatibility with models previously trained in pure bf16, we add this option. 
""" super().__init__() self.dim = dim self.base = float(base) self.pos_idx_in_fp32 = pos_idx_in_fp32 # Generate and save the inverse frequency buffer (non trainable) inv_freq = self._compute_inv_freq(device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.interleaved = interleaved self.scale_base = scale_base scale = ( (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) if scale_base is not None else None ) self.register_buffer("scale", scale, persistent=False) self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None self._cos_k_cached = None self._sin_k_cached = None def _compute_inv_freq(self, device=None): return 1.0 / ( self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim) ) def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): # Reset the tables if the sequence length has changed, # if we're on a new device (possibly due to tracing for instance), # or if we're switching from inference mode to training if ( seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype or (self.training and self._cos_cached.is_inference()) ): self._seq_len_cached = seqlen # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 # And the output of arange can be quite large, so bf16 would lose a lot of precision. # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. if self.pos_idx_in_fp32: t = torch.arange(seqlen, device=device, dtype=torch.float32) # We want fp32 here as well since inv_freq will be multiplied with t, and the output # will be large. Having it in bf16 will lose a lot of precision and cause the # cos & sin output to change significantly. 
# We want to recompute self.inv_freq if it was not loaded in fp32 if self.inv_freq.dtype != torch.float32: inv_freq = self._compute_inv_freq(device=device) else: inv_freq = self.inv_freq else: t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) inv_freq = self.inv_freq # Don't do einsum, it converts fp32 to fp16 under AMP # freqs = torch.einsum("i,j->ij", t, self.inv_freq) freqs = torch.outer(t, inv_freq) if self.scale is None: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: power = ( torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 ) / self.scale_base scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") # We want the multiplication by scale to happen in fp32 self._cos_cached = (torch.cos(freqs) * scale).to(dtype) self._sin_cached = (torch.sin(freqs) * scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) def forward( self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim) kv: (batch, seqlen, 2, nheads, headdim) seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. Most commonly used in inference when we have KV cache. If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one should pass in max_seqlen, which will update the cos / sin cache up to that length. Apply rotary embedding *inplace* to qkv and / or kv. 
""" seqlen = qkv.shape[1] if max_seqlen is not None: self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) elif isinstance(seqlen_offset, int): self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) if kv is None: if self.scale is None: return apply_rotary_emb_qkv_( qkv, self._cos_cached, self._sin_cached, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, ) else: return apply_rotary_emb_qkv_( qkv, self._cos_cached, self._sin_cached, self._cos_k_cached, self._sin_k_cached, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, ) else: q = qkv q = apply_rotary_emb_func( q, self._cos_cached, self._sin_cached, interleaved=self.interleaved, inplace=True, seqlen_offsets=seqlen_offset, ) if self.scale is None: kv = apply_rotary_emb_kv_( kv, self._cos_cached, self._sin_cached, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, ) else: kv = apply_rotary_emb_kv_( kv, self._cos_k_cached, self._sin_k_cached, interleaved=self.interleaved, seqlen_offsets=seqlen_offset, ) return q, kv ================================================ FILE: nit/models/utils/pos_embeds/rope.py ================================================ # -------------------------------------------------------- # FiT: A Flexible Vision Transformer for Image Generation # # Based on the following repository # https://github.com/lucidrains/rotary-embedding-torch # https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope # https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37 # -------------------------------------------------------- import math from math import pi from typing import Optional, Any, Union, Tuple import torch from torch import nn from einops import rearrange, repeat from functools import lru_cache ################################################################################# # NTK Operations # ################################################################################# def 
find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048): return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings)) high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim-1) #Clamp values just in case def linear_ramp_mask(min, max, dim): if min == max: max += 0.001 #Prevent singularity linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func def find_newbase_ntk(dim, base=10000, scale=1): # Base change formula return base * scale ** (dim / (dim-2)) def get_mscale(scale=torch.Tensor): # if scale <= 1: # return 1.0 # return 0.1 * math.log(scale) + 1.0 return torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0) def get_proportion(L_test, L_train): L_test = L_test * 2 return torch.where(torch.tensor(L_test/L_train) <= 1., torch.tensor(1.0), torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train)))) # return torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train))) ################################################################################# # Rotate Q or K # ################################################################################# def rotate_half(x): x = rearrange(x, '... (d r) -> ... d r', r = 2) x1, x2 = x.unbind(dim = -1) x = torch.stack((-x2, x1), dim = -1) return rearrange(x, '... d r -> ... 
(d r)') ################################################################################# # Core Vision RoPE # ################################################################################# class VisionRotaryEmbedding(nn.Module): def __init__( self, head_dim: int, # embed dimension for each head custom_freqs: str = 'normal', theta: int = 10000, online_rope: bool = False, max_cached_len: int = 1024, max_pe_len_h: Optional[int] = None, max_pe_len_w: Optional[int] = None, decouple: bool = False, ori_max_pe_len: Optional[int] = None, ): super().__init__() dim = head_dim // 2 assert dim % 2 == 0 # accually, this is important self.dim = dim self.custom_freqs = custom_freqs.lower() self.theta = theta self.decouple = decouple self.ori_max_pe_len = ori_max_pe_len self.custom_freqs = custom_freqs.lower() if not online_rope: if self.custom_freqs in ['normal', 'scale1', 'scale2']: freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim)) else: if decouple: freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len) freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len) else: max_pe_len = max(max_pe_len_h, max_pe_len_w) freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len) freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len) self.register_buffer('freqs_h', freqs_h, persistent=False) self.register_buffer('freqs_w', freqs_w, persistent=False) if max_pe_len_h != None and max_pe_len_w != None and ori_max_pe_len != None: attn_factor = 1.0 scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0) # dynamic scale self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len) self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2) freqs_h_cached 
= torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h) freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2) self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False) freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w) freqs_w_cached = repeat(freqs_w_cached, '... n -> ... (n r)', r = 2) self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False) def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len): # scaling operations for extrapolation assert isinstance(ori_max_pe_len, int) # scale = max_pe_len / ori_max_pe_len if not isinstance(max_pe_len, torch.Tensor): max_pe_len = torch.tensor(max_pe_len) scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0) # dynamic scale if self.custom_freqs == 'linear': # equal to position interpolation freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim)) elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2': freqs = 1. / torch.pow( find_newbase_ntk(dim, theta, scale).view(-1, 1), (torch.arange(0, dim, 2).to(scale).float() / dim) ).squeeze() elif self.custom_freqs == 'ntk-by-parts': #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) #Do not change unless there is a good reason for doing so! beta_0 = 1.25 beta_1 = 0.75 gamma_0 = 16 gamma_1 = 2 ntk_factor = 1 extrapolation_factor = 1 #Three RoPE extrapolation/interpolation methods freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))) freqs_ntk = 1. 
/ torch.pow( find_newbase_ntk(dim, theta, scale).view(-1, 1), (torch.arange(0, dim, 2).to(scale).float() / dim) ).squeeze() #Combine NTK and Linear low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len) freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask #Combine Extrapolation and NTK and Linear low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len) freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask elif self.custom_freqs == 'yarn': #Interpolation constants found experimentally for LLaMA (might not be totally optimal though) #Do not change unless there is a good reason for doing so! beta_fast = 32 beta_slow = 1 extrapolation_factor = 1 freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)) freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))) low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len) freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask else: raise ValueError(f'Unknown modality {self.custom_freqs}. Only support normal, linear, ntk-aware, ntk-by-parts, yarn!') return freqs def online_get_2d_rope_from_grid(self, grid, size): ''' grid: (B, 2, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 
size: (B, 1, 2), h goes first and w goes last ''' size = size.squeeze() # (B, 1, 2) -> (B, 2) if self.decouple: size_h = size[:, 0] size_w = size[:, 1] freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len) freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len) else: size_max = torch.max(size[:, 0], size[:, 1]) freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len) freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len) freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :] freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :] freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) if self.custom_freqs == 'yarn': freqs_cos = freqs.cos() * self.mscale[:, None, None] freqs_sin = freqs.sin() * self.mscale[:, None, None] elif self.custom_freqs == 'ntk-aware-pro1': freqs_cos = freqs.cos() * self.proportion1[:, None, None] freqs_sin = freqs.sin() * self.proportion1[:, None, None] elif self.custom_freqs == 'ntk-aware-pro2': freqs_cos = freqs.cos() * self.proportion2[:, None, None] freqs_sin = freqs.sin() * self.proportion2[:, None, None] else: freqs_cos = freqs.cos() freqs_sin = freqs.sin() return freqs_cos, freqs_sin @lru_cache() def get_2d_rope_from_grid(self, grid): ''' grid: (B, 2, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] ''' freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h) freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w) freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2) freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) if self.custom_freqs == 'yarn': freqs_cos = freqs.cos() * self.mscale freqs_sin = freqs.sin() * self.mscale elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']: freqs_cos = freqs.cos() * self.proportion1 freqs_sin = freqs.sin() * self.proportion1 elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']: freqs_cos = freqs.cos() * self.proportion2 freqs_sin = freqs.sin() * self.proportion2 else: freqs_cos = freqs.cos() freqs_sin = freqs.sin() return freqs_cos, freqs_sin @lru_cache() def get_cached_2d_rope_from_grid(self, grid: torch.Tensor): ''' grid: (B, 2, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] ''' if len(grid.shape) == 3: # (B, 2, N) freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]] elif len(grid.shape) == 2: # (2, N) freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]] freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) if self.custom_freqs == 'yarn': freqs_cos = freqs.cos() * self.mscale freqs_sin = freqs.sin() * self.mscale elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']: freqs_cos = freqs.cos() * self.proportion1 freqs_sin = freqs.sin() * self.proportion1 elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']: freqs_cos = freqs.cos() * self.proportion2 freqs_sin = freqs.sin() * self.proportion2 else: freqs_cos = freqs.cos() freqs_sin = freqs.sin() return freqs_cos, freqs_sin @lru_cache() def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d rope formulation 2 ! ''' grid: (B, 3, N) N = H * W * T the first dimension represents width, and the second reprensents height, and the third reprensents time e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 
''' freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]+grid[:, 2]], self.freqs_h_cached[grid[:, 1]+grid[:, 2]] freqs = torch.cat([freqs_h, freqs_w], dim=-1) # (B, N, D) if self.custom_freqs == 'yarn': freqs_cos = freqs.cos() * self.mscale freqs_sin = freqs.sin() * self.mscale elif self.custom_freqs == 'ntk-aware-pro1': freqs_cos = freqs.cos() * self.proportion1 freqs_sin = freqs.sin() * self.proportion1 elif self.custom_freqs == 'ntk-aware-pro2': freqs_cos = freqs.cos() * self.proportion2 freqs_sin = freqs.sin() * self.proportion2 else: freqs_cos = freqs.cos() freqs_sin = freqs.sin() return freqs_cos, freqs_sin def forward(self, x, grid): ''' x: (B, n_head, N, D) grid: (B, 2, N) ''' # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid) # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) # using cache to accelerate, this is the same with the above codes: freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid) freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1) return x * freqs_cos + rotate_half(x) * freqs_sin ================================================ FILE: nit/models/utils/pos_embeds/sincos.py ================================================ ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# # modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py import torch import numpy as np from einops import rearrange import torch.nn.functional as F def get_2d_sincos_pos_embed(embed_dim, h, w, frac_coord_size=None, scale_ratio=1.0, cls_token=False, extra_tokens=0): """ args: h / w: int of the grid height / width frac_coord_size: if frac_coord_size != None: fractional coordinates for positional embedding is used else: absolute coordinates for positional embedding is used return: pos_embed: [h*w, embed_dim] or [1+h*w, embed_dim] (w/ or w/o 
cls_token) """ grid_h = torch.arange(h, dtype=torch.float32) grid_w = torch.arange(w, dtype=torch.float32) grid = torch.meshgrid(grid_w, grid_h, indexing='xy') # here w goes first grid = torch.stack(grid, dim=0) grid = rearrange(grid, '... -> 1 ...') # (1, 2, h*w) pos_embed = get_2d_sincos_pos_embed_from_grid( grid, embed_dim, frac_coord_size, scale_ratio ) # 1, L, D if cls_token and extra_tokens > 0: pos_embed = torch.cat([torch.zeros((1, extra_tokens, embed_dim)), pos_embed], dim=1) return pos_embed def get_2d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0): ''' grid: (B, 2, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] frac_coord_size: if frac_coord_size != None: fractional coordinates for positional embedding is used else: absolute coordinates for positional embedding is used ''' assert embed_dim % 2 == 0 grid = grid.float() if frac_coord_size != None: assert isinstance(frac_coord_size, (int, float)) grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size else: grid_w, grid_h = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio # use half of dimensions to encode grid_h emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, embed_dim // 2) # (B, N, D/2) emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, embed_dim // 2) # (B, N, D/2) emb = torch.cat([emb_h, emb_w], dim=-1) # (B, L, D) return emb def get_1d_sincos_pos_embed_from_grid(pos, embed_dim): """ embed_dim: output dimension for each position pos: a batch of list whose positions to be encoded: size (B, N) out: (B, N, D) """ assert embed_dim % 2 == 0 omega = torch.arange(embed_dim // 2, dtype=torch.float64) omega /= embed_dim / 2. omega = 1. 
/ 10000**omega # (D/2,) out = torch.einsum('BL,D->BLD', pos, omega.to(pos)) # (B, N, D/2), outer product emb_sin = torch.sin(out) # (B, N, D/2) emb_cos = torch.cos(out) # (B, N, D/2) emb = torch.cat([emb_sin, emb_cos], dim=-1) # (B, N, D) return emb def get_3d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0, time_dim=0): ''' grid: (B, 3, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] frac_coord_size: if frac_coord_size != None: fractional coordinates for positional embedding is used else: absolute coordinates for positional embedding is used ''' # assert embed_dim % 2 == 0 if time_dim == 0: assert embed_dim % 3 == 0 dim = embed_dim // 3 time_dim = dim else: assert (embed_dim - time_dim) % 2 == 0 dim = (embed_dim - time_dim) // 2 grid = grid.float() if frac_coord_size != None: assert isinstance(frac_coord_size, (int, float)) grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size grid_t = grid[:, 2] / torch.max(grid[:, 2]) * frac_coord_size else: grid_w, grid_h, grid_t = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio, grid[:, 2]*scale_ratio # use half of dimensions to encode grid_h emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, dim) # (B, N, D/2) emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, dim) # (B, N, D/2) emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, time_dim) # (B, N, D/2) emb = torch.cat([emb_t, emb_h, emb_w], dim=-1) # (B, L, D) return emb def get_21d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0): ''' grid: (B, 3, N) N = H * W the first dimension represents width, and the second reprensents height e.g., [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.] [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.] 
frac_coord_size: if frac_coord_size != None: fractional coordinates for positional embedding is used else: absolute coordinates for positional embedding is used ''' assert embed_dim % 2 == 0 dim = embed_dim // 2 grid = grid.float() if frac_coord_size != None: assert isinstance(frac_coord_size, (int, float)) grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size grid_t = grid[:, 2] / torch.max(grid[:, 2]) * frac_coord_size else: grid_w, grid_h, grid_t = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio, grid[:, 2]*scale_ratio # use half of dimensions to encode grid_h emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, dim) emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, dim) emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, embed_dim) emb = torch.cat([emb_h, emb_w], dim=-1) + emb_t # (B, L, D) return emb def get_time_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0): grid = grid.float() grid_t = grid[:, 0]*scale_ratio emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, embed_dim) return emb_t ################################################################################# # interpolation # ################################################################################# def interpolate_sincos_pos_embed(embed_dim, ori_h, ori_w, tgt_h, tgt_w): from src.inf.models.dit import get_2d_sincos_pos_embed pos_embed = get_2d_sincos_pos_embed(embed_dim, ori_h, ori_w) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) pos_embed = rearrange(pos_embed, '1 (h w) d -> 1 d h w', h=ori_h, w=ori_w) pos_embed = F.interpolate(pos_embed, (tgt_h, tgt_w), mode='bilinear') pos_embed = rearrange(pos_embed, '1 d h w -> 1 (h w) d') return pos_embed def interpolate_sincos_pos_index(embed_dim, ori_h, ori_w, tgt_h, tgt_w): from src.inf.models.dit import get_2d_sincos_pos_embed_from_grid grid_h = np.arange(tgt_h, dtype=np.float32) * ori_h / tgt_h grid_w = np.arange(tgt_w, dtype=np.float32) * 
ori_w / tgt_w grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, tgt_h, tgt_w]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0) return pos_embed ================================================ FILE: nit/schedulers/flow_matching/loss.py ================================================ import torch import numpy as np import torch.nn.functional as F def mean_flat(x): """ Take the mean over all non-batch dimensions. """ return torch.mean(x, dim=list(range(1, len(x.size())))) def sum_flat(x): """ Take the mean over all non-batch dimensions. """ return torch.sum(x, dim=list(range(1, len(x.size())))) class FlowMatchingLoss: def __init__( self, prediction='v', path_type="linear", weighting="uniform", encoders=[], accelerator=None, latents_scale=None, latents_bias=None, P_mean=0.0, P_std=1.0, sigma_data=1.0, unit_variance=False, ): self.prediction = prediction self.weighting = weighting self.path_type = path_type self.encoders = encoders self.accelerator = accelerator self.latents_scale = latents_scale self.latents_bias = latents_bias self.P_mean = P_mean self.P_std = P_std self.sigma_data = sigma_data self.unit_variance = unit_variance def interpolant(self, t): if self.path_type == "linear": alpha_t = 1 - t sigma_t = t d_alpha_t = -1 d_sigma_t = 1 elif self.path_type == "cosine": alpha_t = torch.cos(t * torch.pi / 2) sigma_t = torch.sin(t * torch.pi / 2) d_alpha_t = -torch.pi / 2 * torch.sin(t * torch.pi / 2) d_sigma_t = torch.pi / 2 * torch.cos(t * torch.pi / 2) elif self.path_type == 'triangle': alpha_t = torch.cos(t) sigma_t = torch.sin(t) d_alpha_t = -torch.sin(t) d_sigma_t = torch.cos(t) else: raise NotImplementedError() return alpha_t, sigma_t, d_alpha_t, d_sigma_t def __call__(self, model, batch_size, images, noises, model_kwargs=None, use_dir_loss=False, zs=[]): if model_kwargs == None: model_kwargs = {} # sample timestep 
according to log-normal distribution of sigmas following EDM rnd_normal = torch.randn((batch_size)) sigma = (rnd_normal * self.P_std + self.P_mean).exp() if self.path_type == "linear": # [0, 1] t = sigma / (1 + sigma) elif self.path_type == "cosine": # [0, 1] t = 2 / np.pi * torch.atan(sigma) elif self.path_type == 'triangle': # [0, pi/2] t = torch.atan(sigma / self.sigma_data) else: raise NotImplementedError t = t.to(device=images.device, dtype=images.dtype) time_input = t hw_list = model_kwargs['hw_list'] seqlens = hw_list[:, 0] * hw_list[:, 1] t = torch.cat([t[i].unsqueeze(0).repeat(seqlens[i], 1, 1, 1) for i in range(batch_size)], dim=0) alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(t) if self.unit_variance: model_input = alpha_t * images / self.sigma_data + sigma_t * noises else: model_input = alpha_t * images + sigma_t * noises if self.prediction == 'v': model_target = d_alpha_t * images + d_sigma_t * noises else: raise NotImplementedError() # TODO: add x or eps prediction model_kwargs['return_zs'] = True if self.unit_variance: model_output, zs_tilde = self.sigma_data * model(model_input, time_input, **model_kwargs) else: model_output, zs_tilde = model(model_input, time_input, **model_kwargs) denoising_loss = mean_flat((model_output - model_target) ** 2) denoising_loss = torch.nan_to_num(denoising_loss, nan=0, posinf=1e5, neginf=-1e5) loss = denoising_loss.mean() if use_dir_loss: directional_loss = mean_flat(1 - F.cosine_similarity(model_output, model_target, dim=1)) directional_loss = torch.nan_to_num(directional_loss, nan=0, posinf=1e5, neginf=-1e5) loss += directional_loss.mean() proj_loss = 0. 
if zs != [] and zs != None: for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)): proj_loss += 1 - torch.cosine_similarity(z, z_tilde, dim=-1).mean() proj_loss = torch.nan_to_num(proj_loss, nan=0, posinf=1e5, neginf=-1e5) return loss, proj_loss ================================================ FILE: nit/schedulers/flow_matching/samplers_c2i.py ================================================ import torch import numpy as np def expand_t_like_x(t, x_cur, hw_list): """Function to reshape time t to broadcastable dimension of x Args: t: [batch_dim,], time vector x: [batch_dim,...], data point """ dims = [1] * (len(x_cur.size()) - 1) seqlens = hw_list[:, 0] * hw_list[:, 1] B = t.shape[0] t = torch.cat([t[i].unsqueeze(0).repeat(int(seqlens[i]), *dims) for i in range(B)], dim=0) return t def get_score_from_velocity(vt, xt, t, hw_list, path_type="linear"): """Wrapper function: transfrom velocity prediction model to score Args: velocity: [batch_dim, ...] shaped tensor; velocity model output x: [batch_dim, ...] 
shaped tensor; x_t data point t: [batch_dim,] time tensor """ t = expand_t_like_x(t, xt, hw_list) if path_type == "linear": alpha_t, d_alpha_t = 1 - t, torch.ones_like(t, device=t.device) * -1 sigma_t, d_sigma_t = t, torch.ones_like(t, device=t.device) elif path_type == "cosine": alpha_t = torch.cos(t * np.pi / 2) sigma_t = torch.sin(t * np.pi / 2) d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) else: raise NotImplementedError mean = xt reverse_alpha_ratio = alpha_t / d_alpha_t var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t score = (reverse_alpha_ratio * vt - mean) / var return score def compute_diffusion(t_cur): return 2 * t_cur def euler_sampler( model, ag_model, latents, y, hw_list, num_steps=20, heun=False, cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", # not used, just for compatability ): # setup conditioning if cfg_scale > 1.0: y_null = torch.tensor([1000] * y.size(0), device=y.device) if ag_model != None: auto_guidance = True else: auto_guidance = False _dtype = latents.dtype t_steps = torch.linspace(1, 0, num_steps+1, dtype=torch.float64) x_next = latents.to(torch.float64) device = x_next.device with torch.no_grad(): for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): x_cur = x_next if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) hw_list_cur = torch.cat([hw_list, hw_list], dim=0) else: model_input = x_cur y_cur = y hw_list_cur = hw_list kwargs = dict(y=y_cur, hw_list=hw_list_cur) time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur d_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low: if auto_guidance: kwargs = dict(y=y_null, hw_list=hw_list_cur) time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur d_cur_uncond = ag_model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond) else: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) x_next = x_cur + (t_next - t_cur) * d_cur if heun and (i < num_steps - 1): if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_next] * 2) y_cur = torch.cat([y, y_null], dim=0) hw_list_cur = torch.cat([hw_list, hw_list], dim=0) else: model_input = x_next y_cur = y hw_list_cur = hw_list kwargs = dict(y=y_cur, hw_list=hw_list_cur) time_input = torch.ones(y_cur.size(0)).to( device=model_input.device, dtype=torch.float64 ) * t_next d_prime = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: if auto_guidance: kwargs = dict(y=y_null, hw_list=hw_list_cur) time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_next d_prime_uncond = ag_model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) d_prime = d_prime_uncond + cfg_scale * (d_prime - d_prime_uncond) else: d_prime_cond, d_prime_uncond = d_prime.chunk(2) d_prime = d_prime_uncond + cfg_scale * (d_prime_cond - d_prime_uncond) x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime) return x_next def euler_maruyama_sampler( model, ag_model, latents, y, hw_list, num_steps=20, heun=False, # not used, just for compatability cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", ): # setup conditioning if cfg_scale > 1.0: y_null = torch.tensor([1000] * y.size(0), 
device=y.device) if ag_model != None: auto_guidance = True else: auto_guidance = False _dtype = latents.dtype t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64) t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)]) x_next = latents.to(torch.float64) device = x_next.device with torch.no_grad(): for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])): dt = t_next - t_cur x_cur = x_next if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) hw_list_cur = torch.cat([hw_list, hw_list], dim=0) else: model_input = x_cur y_cur = y hw_list_cur = hw_list kwargs = dict(y=y_cur, hw_list=hw_list_cur) time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur diffusion = compute_diffusion(t_cur) eps_i = torch.randn_like(x_cur).to(device) deps = eps_i * torch.sqrt(torch.abs(dt)) # compute drift v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type) d_cur = v_cur - 0.5 * diffusion * s_cur if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low: if auto_guidance: kwargs = dict(y=y_null, hw_list=hw_list_cur) time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur diffusion = compute_diffusion(t_cur) eps_i = torch.randn_like(x_cur).to(device) deps = eps_i * torch.sqrt(torch.abs(dt)) # compute drift v_cur_uncond = ag_model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type) d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond) else: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps # last step t_cur, t_next = t_steps[-2], t_steps[-1] dt = t_next - t_cur x_cur = x_next if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) hw_list_cur = torch.cat([hw_list, hw_list], dim=0) else: model_input = x_cur y_cur = y hw_list_cur = hw_list kwargs = dict(y=y_cur, hw_list=hw_list_cur) time_input = torch.ones(y_cur.size(0)).to( device=device, dtype=torch.float64 ) * t_cur # compute drift v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type) diffusion = compute_diffusion(t_cur) d_cur = v_cur - 0.5 * diffusion * s_cur if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low: if auto_guidance: kwargs = dict(y=y_null, hw_list=hw_list_cur) time_input = torch.ones(y_null.size(0)).to( device=device, dtype=torch.float64 ) * t_cur # compute drift v_cur_uncond = ag_model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs ).to(torch.float64) s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type) diffusion = compute_diffusion(t_cur) d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond) else: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) mean_x = x_cur + dt * d_cur return mean_x ================================================ FILE: nit/utils/__init__.py ================================================ from .misc_utils import * from .train_utils import * from .eval_utils import * from .gpu_memory_monitor import * ================================================ FILE: nit/utils/deepspeed_zero_to_fp32.py ================================================ #!/usr/bin/env python # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in # the future. Once extracted, the weights don't require DeepSpeed and can be used in any # application. # # example: python zero_to_fp32.py . pytorch_model.bin import argparse import torch import glob import math import os import re from collections import OrderedDict from dataclasses import dataclass # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with # DeepSpeed data structures it has to be available in the current python environment. 
from deepspeed.utils import logger from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) @dataclass class zero_model_state: buffers: dict() param_shapes: dict() shared_params: list ds_version: int frozen_param_shapes: dict() frozen_param_fragments: dict() debug = 0 # load to cpu device = torch.device('cpu') def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): ''' alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) ''' return [atoi(c) for c in re.split(r'(\d+)', text)] def get_model_state_file(checkpoint_dir, zero_stage): if not os.path.isdir(checkpoint_dir): raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") # there should be only one file if zero_stage <= 2: file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") elif zero_stage == 3: file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") if not os.path.exists(file): raise FileNotFoundError(f"can't find model states file at '{file}'") return file def get_checkpoint_files(checkpoint_dir, glob_pattern): # XXX: need to test that this simple glob rule works for multi-node setup too ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) if len(ckpt_files) == 0: raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") return ckpt_files def get_optim_files(checkpoint_dir): return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") def get_model_state_files(checkpoint_dir): return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") def parse_model_states(files): zero_model_states = [] for file in files: state_dict = torch.load(file, map_location=device) if BUFFER_NAMES not in 
state_dict: raise ValueError(f"{file} is not a model state checkpoint") buffer_names = state_dict[BUFFER_NAMES] if debug: print("Found buffers:", buffer_names) # recover just the buffers while restoring them to fp32 if they were saved in fp16 buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} param_shapes = state_dict[PARAM_SHAPES] # collect parameters that are included in param_shapes param_names = [] for s in param_shapes: for name in s.keys(): param_names.append(name) # update with frozen parameters frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) if frozen_param_shapes is not None: if debug: print(f"Found frozen_param_shapes: {frozen_param_shapes}") param_names += list(frozen_param_shapes.keys()) # handle shared params shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] ds_version = state_dict.get(DS_VERSION, None) frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) z_model_state = zero_model_state(buffers=buffers, param_shapes=param_shapes, shared_params=shared_params, ds_version=ds_version, frozen_param_shapes=frozen_param_shapes, frozen_param_fragments=frozen_param_fragments) zero_model_states.append(z_model_state) return zero_model_states def parse_optim_states(files, ds_checkpoint_dir): total_files = len(files) state_dicts = [] for f in files: state_dict = torch.load(f, map_location=device) # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights # and also handle the case where it was already removed by another helper script state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) state_dicts.append(state_dict) if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: raise ValueError(f"{files[0]} is not a zero checkpoint") zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] # For ZeRO-2 each param group can have different 
partition_count as data parallelism for expert # parameters can be different from data parallelism for non-expert parameters. So we can just # use the max of the partition_count to get the dp world_size. if type(world_size) is list: world_size = max(world_size) if world_size != total_files: raise ValueError( f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." ) # the groups are named differently in each stage if zero_stage <= 2: fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS elif zero_stage == 3: fp32_groups_key = FP32_FLAT_GROUPS else: raise ValueError(f"unknown zero stage {zero_stage}") if zero_stage <= 2: fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] elif zero_stage == 3: # if there is more than one param group, there will be multiple flattened tensors - one # flattened tensor per group - for simplicity merge them into a single tensor # # XXX: could make the script more memory efficient for when there are multiple groups - it # will require matching the sub-lists of param_shapes for each param group flattened tensor fp32_flat_groups = [ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) ] return zero_stage, world_size, fp32_flat_groups def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint Args: - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) """ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") optim_files = get_optim_files(ds_checkpoint_dir) zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") 
model_files = get_model_state_files(ds_checkpoint_dir) zero_model_states = parse_model_states(model_files) print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) elif zero_stage == 3: return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters) def _zero2_merge_frozen_params(state_dict, zero_model_states): if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: return frozen_param_shapes = zero_model_states[0].frozen_param_shapes frozen_param_fragments = zero_model_states[0].frozen_param_fragments if debug: num_elem = sum(s.numel() for s in frozen_param_shapes.values()) print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') wanted_params = len(frozen_param_shapes) wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) print(f'Frozen params: Have {avail_numel} numels to process.') print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') total_params = 0 total_numel = 0 for name, shape in frozen_param_shapes.items(): total_params += 1 unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel state_dict[name] = frozen_param_fragments[name] if debug: print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") def _has_callable(obj, fn): attr = getattr(obj, fn, None) return callable(attr) def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): param_shapes = zero_model_states[0].param_shapes # Reconstruction protocol: # # XXX: document this if debug: for i in range(world_size): for j in 
range(len(fp32_flat_groups[0])): print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") # XXX: memory usage doubles here (zero2) num_param_groups = len(fp32_flat_groups[0]) merged_single_partition_of_fp32_groups = [] for i in range(num_param_groups): merged_partitions = [sd[i] for sd in fp32_flat_groups] full_single_fp32_vector = torch.cat(merged_partitions, 0) merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) avail_numel = sum( [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) if debug: wanted_params = sum([len(shapes) for shapes in param_shapes]) wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) # not asserting if there is a mismatch due to possible padding print(f"Have {avail_numel} numels to process.") print(f"Need {wanted_numel} numels in {wanted_params} params.") # params # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # out-of-core computing solution total_numel = 0 total_params = 0 for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): offset = 0 avail_numel = full_single_fp32_vector.numel() for name, shape in shapes.items(): unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape) total_numel += unpartitioned_numel total_params += 1 if debug: print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) offset += unpartitioned_numel # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex # paddings performed in the code it's almost impossible to predict the exact numbers w/o the # live optimizer object, so we are checking that the numbers are within the right range align_to = 2 * world_size def zero2_align(x): return align_to * math.ceil(x / align_to) if debug: print(f"original offset={offset}, avail_numel={avail_numel}") offset = zero2_align(offset) avail_numel = zero2_align(avail_numel) if debug: print(f"aligned offset={offset}, avail_numel={avail_numel}") # Sanity check if offset != avail_numel: raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() # buffers buffers = zero_model_states[0].buffers state_dict.update(buffers) if debug: print(f"added {len(buffers)} buffers") if not exclude_frozen_parameters: _zero2_merge_frozen_params(state_dict, zero_model_states) _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) # recover shared parameters for pair in zero_model_states[0].shared_params: if pair[1] in state_dict: state_dict[pair[0]] = state_dict[pair[1]] return state_dict def zero3_partitioned_param_info(unpartitioned_numel, world_size): remainder = unpartitioned_numel % world_size padding_numel = (world_size - remainder) if remainder else 0 partitioned_numel = math.ceil(unpartitioned_numel / world_size) return partitioned_numel, padding_numel def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: return if debug: for i in range(world_size): num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = 
{num_elem}') frozen_param_shapes = zero_model_states[0].frozen_param_shapes wanted_params = len(frozen_param_shapes) wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size print(f'Frozen params: Have {avail_numel} numels to process.') print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') total_params = 0 total_numel = 0 for name, shape in zero_model_states[0].frozen_param_shapes.items(): total_params += 1 unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) if debug: print( f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" ) print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): param_shapes = zero_model_states[0].param_shapes avail_numel = fp32_flat_groups[0].numel() * world_size # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each # param, re-consolidating each param, while dealing with padding if any # merge list of dicts, preserving order param_shapes = {k: v for d in param_shapes for k, v in d.items()} if debug: for i in range(world_size): print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") wanted_params = len(param_shapes) wanted_numel = sum(shape.numel() for shape in param_shapes.values()) # not asserting if there is a mismatch due to possible padding avail_numel = fp32_flat_groups[0].numel() * world_size 
print(f"Trainable params: Have {avail_numel} numels to process.") print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") # params # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support # out-of-core computing solution offset = 0 total_numel = 0 total_params = 0 for name, shape in param_shapes.items(): unpartitioned_numel = shape.numel() total_numel += unpartitioned_numel total_params += 1 partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) if debug: print( f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" ) # XXX: memory usage doubles here state_dict[name] = torch.cat( tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), 0).narrow(0, 0, unpartitioned_numel).view(shape) offset += partitioned_numel offset *= world_size # Sanity check if offset != avail_numel: raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, exclude_frozen_parameters): state_dict = OrderedDict() # buffers buffers = zero_model_states[0].buffers state_dict.update(buffers) if debug: print(f"added {len(buffers)} buffers") if not exclude_frozen_parameters: _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) # recover shared parameters for pair in zero_model_states[0].shared_params: if pair[1] in state_dict: state_dict[pair[0]] = state_dict[pair[1]] return state_dict def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint 
into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example via a model hub. Args: - ``checkpoint_dir``: path to the desired checkpoint folder - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` - ``exclude_frozen_parameters``: exclude frozen parameters Returns: - pytorch ``state_dict`` Note: this approach may not work if your application doesn't have sufficient free CPU memory and you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with the checkpoint. A typical usage might be :: from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint # do the training and checkpoint saving state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu model = model.cpu() # move to cpu model.load_state_dict(state_dict) # submit to model hub or save the model to share with others In this example the ``model`` will no longer be usable in the deepspeed context of the same application. i.e. you will need to re-initialize the deepspeed engine, since ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. 
""" if tag is None: latest_path = os.path.join(checkpoint_dir, 'latest') if os.path.isfile(latest_path): with open(latest_path, 'r') as fd: tag = fd.read().strip() else: raise ValueError(f"Unable to find 'latest' file at {latest_path}") ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. Args: - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` - ``exclude_frozen_parameters``: exclude frozen parameters """ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) print(f"Saving fp32 state dict to {output_file}") torch.save(state_dict, output_file) def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): """ 1. Put the provided model to cpu 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` 3. Load it into the provided model Args: - ``model``: the model object to update - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - ``tag``: checkpoint tag used as a unique identifier for checkpoint. 
If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` Returns: - ``model`: modified model Make sure you have plenty of CPU memory available before you call this function. If you don't have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it conveniently placed for you in the checkpoint folder. A typical usage might be :: from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) # submit to model hub or save the model to share with others Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context of the same application. i.e. you will need to re-initialize the deepspeed engine, since ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. """ logger.info(f"Extracting fp32 weights") state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) logger.info(f"Overwriting model with fp32 weights") model = model.cpu() model.load_state_dict(state_dict, strict=False) return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12") parser.add_argument( "output_file", type=str, help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") parser.add_argument("-t", "--tag", type=str, default=None, help="checkpoint tag used as a unique identifier for checkpoint. 
e.g., global_step1") parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag, exclude_frozen_parameters=args.exclude_frozen_parameters) ================================================ FILE: nit/utils/ema.py ================================================ import torch from collections import OrderedDict from copy import deepcopy @torch.no_grad() def update_ema(ema_model, model, decay=0.9999): """ Step the EMA model towards the current model. """ if hasattr(model, 'module'): model = model.module if hasattr(ema_model, 'module'): ema_model = ema_model.module ema_params = OrderedDict(ema_model.named_parameters()) model_params = OrderedDict(model.named_parameters()) for name, param in model_params.items(): # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) ================================================ FILE: nit/utils/eval_utils.py ================================================ from PIL import Image import numpy as np from tqdm import tqdm import torch import re import os from safetensors.torch import load_file def create_npz_from_sample_folder(sample_dir, num=50_000): """ Builds a single .npz file from a folder of .png samples. 
""" samples = [] imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0])) print(len(imgs)) assert len(imgs) >= num for i in tqdm(range(num), desc="Building .npz file from samples"): sample_pil = Image.open(f"{sample_dir}/{imgs[i]}") sample_np = np.asarray(sample_pil).astype(np.uint8) samples.append(sample_np) samples = np.stack(samples) assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) npz_path = f"{sample_dir}.npz" np.savez(npz_path, arr_0=samples) print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") return npz_path def init_from_ckpt( model, checkpoint_dir, ignore_keys=None, verbose=False ) -> None: if checkpoint_dir.endswith(".safetensors"): model_state_dict=load_file(checkpoint_dir, device='cpu') else: model_state_dict=torch.load(checkpoint_dir, map_location="cpu") model_new_ckpt=dict() for i in model_state_dict.keys(): model_new_ckpt[i] = model_state_dict[i] keys = list(model_new_ckpt.keys()) for k in keys: if ignore_keys: for ik in ignore_keys: if ik in k: print("Deleting key {} from state_dict.".format(k)) del model_new_ckpt[k] missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False) if verbose: print( f"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") if verbose: print("") def none_or_str(value): if value == 'None': return None return value def parse_sde_args(parser): group = parser.add_argument_group("SDE arguments") group.add_argument("--sde-sampling-method", type=str, default="Euler", choices=["Euler", "Heun"]) group.add_argument("--diffusion-form", type=str, default="sigma", \ choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],\ help="form of diffusion coefficient in the SDE") group.add_argument("--diffusion-norm", type=float, default=1.0) group.add_argument("--last-step", type=none_or_str, default="Mean", 
choices=[None, "Mean", "Tweedie", "Euler"],\ help="form of last step taken in the SDE") group.add_argument("--last-step-size", type=float, default=0.04, \ help="size of the last step taken") def parse_ode_args(parser): group = parser.add_argument_group("ODE arguments") group.add_argument("--ode-sampling-method", type=str, default="dopri5", help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq") group.add_argument("--atol", type=float, default=1e-6, help="Absolute tolerance") group.add_argument("--rtol", type=float, default=1e-3, help="Relative tolerance") group.add_argument("--reverse", action="store_true") group.add_argument("--likelihood", action="store_true") # ode solvers: # - Adaptive-step: # - dopri8 Runge-Kutta 7(8) of Dormand-Prince-Shampine # - dopri5 Runge-Kutta 4(5) of Dormand-Prince [default]. # - bosh3 Runge-Kutta 2(3) of Bogacki-Shampine # - adaptive_heun Runge-Kutta 1(2) # - Fixed-step: # - euler Euler method. # - midpoint Midpoint method. # - rk4 Fourth-order Runge-Kutta with 3/8 rule. # - explicit_adams Explicit Adams. # - implicit_adams Implicit Adams. 
================================================ FILE: nit/utils/freeze.py ================================================ from diffusers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name def freeze_model(model, trainable_modules={}, verbose=False): logger.info("Start freeze") for name, param in model.named_parameters(): param.requires_grad = False if verbose: logger.info("freeze moduel: "+str(name)) for trainable_module_name in trainable_modules: if trainable_module_name in name: param.requires_grad = True if verbose: logger.info("unfreeze moduel: "+str(name)) break logger.info("End freeze") params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()] params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()] logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M") logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M") return ================================================ FILE: nit/utils/gpu_memory_monitor.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import os from collections import namedtuple from datetime import datetime from typing import Any, Dict, Optional import torch # named tuple for passing GPU memory stats for logging GPUMemStats = namedtuple( "GPUMemStats", [ "max_active_gib", "max_active_pct", "max_reserved_gib", "max_reserved_pct", "num_alloc_retries", "num_ooms", ], ) class GPUMemoryMonitor: def __init__(self, logger, device: str = "cuda:0"): self.device = torch.device(device) # device object self.device_name = torch.cuda.get_device_name(self.device) self.device_index = torch.cuda.current_device() self.device_capacity = torch.cuda.get_device_properties( self.device ).total_memory self.device_capacity_gib = self._to_gib(self.device_capacity) self.logger = logger torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() def _to_gib(self, memory_in_bytes): # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 _gib_in_bytes = 1024 * 1024 * 1024 memory_in_gib = memory_in_bytes / _gib_in_bytes return memory_in_gib def _to_pct(self, memory): return 100 * memory / self.device_capacity def get_peak_stats(self): cuda_info = torch.cuda.memory_stats(self.device) max_active = cuda_info["active_bytes.all.peak"] max_active_gib = self._to_gib(max_active) max_active_pct = self._to_pct(max_active) max_reserved = cuda_info["reserved_bytes.all.peak"] max_reserved_gib = self._to_gib(max_reserved) max_reserved_pct = self._to_pct(max_reserved) num_retries = cuda_info["num_alloc_retries"] num_ooms = cuda_info["num_ooms"] if num_retries > 0: self.logger.warning(f"{num_retries} CUDA memory allocation retries.") if num_ooms > 0: self.logger.warning(f"{num_ooms} CUDA OOM errors thrown.") return GPUMemStats( max_active_gib, max_active_pct, max_reserved_gib, max_reserved_pct, num_retries, num_ooms, ) def reset_peak_stats(self): torch.cuda.reset_peak_memory_stats() def build_gpu_memory_monitor(logger): gpu_memory_monitor = GPUMemoryMonitor(logger, "cuda") logger.info( f"GPU capacity: {gpu_memory_monitor.device_name} 
({gpu_memory_monitor.device_index}) " f"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory" ) return gpu_memory_monitor ================================================ FILE: nit/utils/lr_scheduler.py ================================================ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch optimization for diffusion models.""" import math from enum import Enum from typing import Optional, Union from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR class SchedulerType(Enum): LINEAR = "linear" COSINE = "cosine" COSINE_WITH_RESTARTS = "cosine_with_restarts" POLYNOMIAL = "polynomial" CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" PIECEWISE_CONSTANT = "piecewise_constant" WARMDUP_STABLE_DECAY = "warmup_stable_decay" def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) def get_constant_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, div_factor: int = 1e-4, last_epoch: int = -1 ): def lr_lambda(current_step): # 0,y0 step,y1 #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1 if current_step < num_warmup_steps: return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch) def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): """ Create a schedule with a constant learning rate, using the learning rate set in optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. step_rules (`string`): The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps and multiple 0.005 for the other steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" rules_dict = {} rule_list = step_rules.split(",") for rule_str in rule_list[:-1]: value_str, steps_str = rule_str.split(":") steps = int(steps_str) value = float(value_str) rules_dict[steps] = value last_lr_multiple = float(rule_list[-1]) def create_rules_function(rules_dict, last_lr_multiple): def rule_func(steps: int) -> float: sorted_steps = sorted(rules_dict.keys()) for i, sorted_step in enumerate(sorted_steps): if steps < sorted_step: return rules_dict[sorted_steps[i]] return last_lr_multiple return rule_func rules_func = create_rules_function(rules_dict, last_lr_multiple) return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) return max( 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) ) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_periods (`float`, *optional*, defaults to 0.5): The number of periods of the cosine function in a schedule (the default is to just decrease from the max value to 0 following a half-cosine). last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_cosine_with_hard_restarts_schedule_with_warmup( optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases linearly between 0 and the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. num_cycles (`int`, *optional*, defaults to 1): The number of hard restarts to use. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" def lr_lambda(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) if progress >= 1.0: return 0.0 return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) return LambdaLR(optimizer, lr_lambda, last_epoch) def get_polynomial_decay_schedule_with_warmup( optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1 ): """ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. Args: optimizer ([`~torch.optim.Optimizer`]): The optimizer for which to schedule the learning rate. num_warmup_steps (`int`): The number of steps for the warmup phase. num_training_steps (`int`): The total number of training steps. lr_end (`float`, *optional*, defaults to 1e-7): The end LR. power (`float`, *optional*, defaults to 1.0): Power factor. last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37 Return: `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
""" lr_init = optimizer.defaults["lr"] if not (lr_init > lr_end): raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})") def lr_lambda(current_step: int): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) elif current_step > num_training_steps: return lr_end / lr_init # as LambdaLR multiplies by lr_init else: lr_range = lr_init - lr_end decay_steps = num_training_steps - num_warmup_steps pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps decay = lr_range * pct_remaining**power + lr_end return decay / lr_init # as LambdaLR multiplies by lr_init return LambdaLR(optimizer, lr_lambda, last_epoch) def get_constant_schedule_with_warmup_and_decay( optimizer: Optimizer, num_warmup_steps: int, num_decay_steps: int, decay_T: int = 50000, div_factor: int = 1e-4, last_epoch: int = -1 ): def lr_lambda(current_step): # 0,y0 step,y1 #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1 if current_step < num_warmup_steps: return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor if current_step > num_decay_steps: return 0.5 ** ((current_step - num_decay_steps) / decay_T) return 1.0 return LambdaLR(optimizer, lr_lambda, last_epoch) TYPE_TO_SCHEDULER_FUNCTION = { SchedulerType.LINEAR: get_linear_schedule_with_warmup, SchedulerType.COSINE: get_cosine_schedule_with_warmup, SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup, SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, SchedulerType.CONSTANT: get_constant_schedule, SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, SchedulerType.WARMDUP_STABLE_DECAY: get_constant_schedule_with_warmup_and_decay } def get_scheduler( name: Union[str, SchedulerType], optimizer: Optimizer, step_rules: Optional[str] = None, num_warmup_steps: Optional[int] = None, num_decay_steps: 
Optional[int] = None, num_training_steps: Optional[int] = None, num_cycles: int = 1, decay_T: Optional[int] = 50000, power: float = 1.0, last_epoch: int = -1, ): """ Unified API to get any scheduler from its name. Args: name (`str` or `SchedulerType`): The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. step_rules (`str`, *optional*): A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. num_warmup_steps (`int`, *optional*): The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_decay_steps (`int`, *optional*): The number of decay steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_training_steps (`int``, *optional*): The number of training steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. num_cycles (`int`, *optional*): The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. power (`float`, *optional*, defaults to 1.0): Power factor. See `POLYNOMIAL` scheduler decay_T (`int`, *optional*, defaults to 50000): Power factor. See `POLYNOMIAL` scheduler last_epoch (`int`, *optional*, defaults to -1): The index of the last epoch when resuming training. 
""" name = SchedulerType(name) schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return schedule_func(optimizer, last_epoch=last_epoch) if name == SchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch) if name == SchedulerType.WARMDUP_STABLE_DECAY: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_decay_steps=num_decay_steps, decay_T=decay_T, last_epoch=last_epoch) # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") if name == SchedulerType.COSINE_WITH_RESTARTS: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles, last_epoch=last_epoch, ) if name == SchedulerType.POLYNOMIAL: return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, last_epoch=last_epoch, ) return schedule_func( optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch ) ================================================ FILE: nit/utils/misc_utils.py ================================================ import functools import importlib import os import wandb import fsspec import numpy as np import torch from dataclasses import dataclass from functools import partial from inspect import isfunction from PIL import Image, ImageDraw, ImageFont from safetensors.torch import load_file as load_safetensors def get_dtype(str_dtype): if str_dtype == 'fp16': return torch.float16 elif str_dtype == 'bf16': 
return torch.bfloat16 else: return torch.float32 def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def get_string_from_tuple(s): try: # Check if the string starts and ends with parentheses if s[0] == "(" and s[-1] == ")": # Convert the string to a tuple t = eval(s) # Check if the type of t is tuple if type(t) == tuple: return t[0] else: pass except: pass return s def is_power_of_two(n): """ chat.openai.com/chat Return True if n is a power of 2, otherwise return False. The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 
""" if n <= 0: return False return (n & (n - 1)) == 0 def autocast(f, enabled=True): def do_autocast(*args, **kwargs): with torch.cuda.amp.autocast( enabled=enabled, dtype=torch.get_autocast_gpu_dtype(), cache_enabled=torch.is_autocast_cache_enabled(), ): return f(*args, **kwargs) return do_autocast def load_partial_from_config(config): return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) txts = list() for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) if isinstance(xc[bi], list): text_seq = xc[bi][0] else: text_seq = xc[bi] lines = "\n".join( text_seq[start : start + nc] for start in range(0, len(text_seq), nc) ) try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: print("Cant encode string for logging. 
Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) return NewCls def make_path_absolute(path): fs, p = fsspec.core.url_to_fs(path) if fs.protocol == "file": return os.path.abspath(p) return path def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def isheatmap(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 2 def isneighbors(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) def exists(x): return x is not None def expand_dims_like(x, y): while x.dim() != y.dim(): x = x.unsqueeze(-1) return x def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. 
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False, invalidate_cache=True): module, cls = string.rsplit(".", 1) if invalidate_cache: importlib.invalidate_caches() if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError( f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" ) return x[(...,) + (None,) * dims_to_append] def load_model_from_config(config, ckpt, verbose=True, freeze=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) elif ckpt.endswith("bin"): sd = torch.load(ckpt, map_location="cpu") else: raise NotImplementedError model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if freeze: for param in model.parameters(): param.requires_grad = False model.eval() return model def 
format_number(num): num = float(num) num /= 1000.0 return '{:.0f}{}'.format(num, 'k') def get_num_params(model: torch.nn.ModuleList) -> int: num_params = sum(p.numel() for p in model.parameters()) return num_params def get_num_flop_per_token(num_params, model_config, seq_len) -> int: l, h, q, t = ( model_config.n_layers, model_config.n_heads, model_config.dim // model_config.n_heads, seq_len, ) # Reasoning behind the factor of 12 for the self-attention part of the formula: # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) # 2. the flash attention does 1 more matmul recomputation in the backward # but recomputation should not be counted in calculating MFU (+0) # 3. each matmul performs 1 multiplication and 1 addition (*2) # 4. we follow the convention and do not account for sparsity in causal attention flop_per_token = 6 * num_params + 12 * l * h * q * t return flop_per_token def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int: l, h, q = ( model_config.n_layers, model_config.n_heads, model_config.dim // model_config.n_heads, ) # 1. 每个自注意力层有2个矩阵乘法在前向传播,4个在反向传播 (6) # 2. 每个矩阵乘法执行1次乘法和1次加法 (*2) # 3. 双向注意力需要考虑所有token对,所以是t^2而不是t flop_per_sequence = 6 * num_params + 12 * l * h * q * seq_len * seq_len return flop_per_sequence # hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU def get_peak_flops(device_name: str) -> int: if "A100" in device_name: # data from https://www.nvidia.com/en-us/data-center/a100/ return 312e12 elif "H100" in device_name: # data from https://www.nvidia.com/en-us/data-center/h100/ # NOTE: Specifications are one-half lower without sparsity. 
if "NVL" in device_name: return 1979e12 elif "PCIe" in device_name: return 756e12 else: # for SXM and other variants return 989e12 else: # for other GPU types, assume A100 return 312e12 @dataclass(frozen=True) class Color: black = "\033[30m" red = "\033[31m" green = "\033[32m" yellow = "\033[33m" blue = "\033[34m" magenta = "\033[35m" cyan = "\033[36m" white = "\033[37m" reset = "\033[39m" @dataclass(frozen=True) class NoColor: black = "" red = "" green = "" yellow = "" blue = "" magenta = "" cyan = "" white = "" reset = "" ================================================ FILE: nit/utils/model_utils.py ================================================ import os import torch from transformers import T5EncoderModel, AutoModelForCausalLM, AutoTokenizer # dc-ae def dc_ae_encode(dc_ae, images): with torch.no_grad(): latents = dc_ae.encode(images).latent * dc_ae.config.scaling_factor return latents def dc_ae_decode(dc_ae, latents): with torch.no_grad(): z = latents / dc_ae.config.scaling_factor if dc_ae.use_slicing and z.size(0) > 1: decoded_slices = [dc_ae._decode(z_slice) for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: decoded = dc_ae._decode(z) images = decoded # decoded images return images # sd-vae def sd_vae_encode(sd_vae, images): with torch.no_grad(): z = sd_vae.encode(images) if isinstance(z, dict): z=z.latent_dist.sample() z = sd_vae.config.scaling_factor * z return z def sd_vae_decode(sd_vae, latents): with torch.no_grad(): z = 1.0 / sd_vae.config.scaling_factor * latents out = sd_vae.decode(z) if isinstance(out, dict): out=out.sample return out # load text-encoder def load_text_encoder(text_encoder_dir, device, weight_dtype): os.environ["TOKENIZERS_PARALLELISM"] = "true" tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir) if 'gemma' in text_encoder_dir: tokenizer.padding_side = "right" text_encoder = AutoModelForCausalLM.from_pretrained( text_encoder_dir, attn_implementation="flash_attention_2", device_map='cpu', 
torch_dtype=weight_dtype ).get_decoder() elif 't5' in text_encoder_dir: text_encoder = T5EncoderModel.from_pretrained( text_encoder_dir, attn_implementation="sdpa", device_map='cpu', torch_dtype=weight_dtype ) else: raise NotImplementedError text_encoder.requires_grad_(False) text_encoder = text_encoder.eval().to(device=device, dtype=weight_dtype) return text_encoder, tokenizer def encode_prompt(tokenizer, text_encoder, device, weight_dtype, captions, use_last_hidden_state, max_seq_length=256): text_inputs = tokenizer( captions, padding='max_length', max_length=max_seq_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) prompt_masks = text_inputs.attention_mask.to(device) with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): results = text_encoder( input_ids=text_input_ids, attention_mask=prompt_masks, output_hidden_states=True, ) if use_last_hidden_state: prompt_embeds = results.last_hidden_state else: # from Imagen paper prompt_embeds = results.hidden_states[-2] return prompt_embeds, prompt_masks def prepare_null_cap_feat_mask(text_encoder_type, device, weight_dtype, use_last_hidden_state, max_seq_length=256): text_encoder, tokenizer = load_text_encoder( text_encoder_dir=text_encoder_type, device=device, weight_dtype=weight_dtype ) null_cap_features, null_cap_mask = encode_prompt( tokenizer, text_encoder, device, weight_dtype, "", use_last_hidden_state, max_seq_length=max_seq_length ) return null_cap_features, null_cap_mask ================================================ FILE: nit/utils/train_utils.py ================================================ import torch from collections import OrderedDict from copy import deepcopy from diffusers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name def freeze_model(model, trainable_modules={}, verbose=False): logger.info("Start freeze") for name, param in model.named_parameters(): param.requires_grad = False if verbose: 
logger.info("freeze moduel: "+str(name)) for trainable_module_name in trainable_modules: if trainable_module_name in name: param.requires_grad = True if verbose: logger.info("unfreeze moduel: "+str(name)) break logger.info("End freeze") params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()] params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()] logger.info(f"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M") logger.info(f"Freeze Module Parameters: {sum(params_freeze) / 1e6} M") return @torch.no_grad() def update_ema(ema_model, model, decay=0.9999): """ Step the EMA model towards the current model. """ if hasattr(model, 'module'): model = model.module if hasattr(ema_model, 'module'): ema_model = ema_model.module ema_params = OrderedDict(ema_model.named_parameters()) model_params = OrderedDict(model.named_parameters()) for name, param in model_params.items(): # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) def log_validation(model): pass ================================================ FILE: nit/utils/util.py ================================================ import functools import importlib import os import wandb import fsspec import numpy as np import torch from dataclasses import dataclass from functools import partial from inspect import isfunction from PIL import Image, ImageDraw, ImageFont from safetensors.torch import load_file as load_safetensors def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def get_string_from_tuple(s): try: # Check if the string starts and ends with parentheses if s[0] == "(" and s[-1] == ")": # Convert the string to a tuple t = eval(s) # Check if the type of t is tuple if type(t) == tuple: return t[0] else: pass except: 
pass return s def is_power_of_two(n): """ chat.openai.com/chat Return True if n is a power of 2, otherwise return False. The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. """ if n <= 0: return False return (n & (n - 1)) == 0 def autocast(f, enabled=True): def do_autocast(*args, **kwargs): with torch.cuda.amp.autocast( enabled=enabled, dtype=torch.get_autocast_gpu_dtype(), cache_enabled=torch.is_autocast_cache_enabled(), ): return f(*args, **kwargs) return do_autocast def load_partial_from_config(config): return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) txts = list() for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) nc = int(40 * (wh[0] / 256)) if isinstance(xc[bi], list): text_seq = xc[bi][0] else: text_seq = xc[bi] lines = "\n".join( text_seq[start : start + nc] for start in range(0, len(text_seq), nc) ) try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: print("Cant encode string 
for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) return NewCls def make_path_absolute(path): fs, p = fsspec.core.url_to_fs(path) if fs.protocol == "file": return os.path.abspath(p) return path def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def isheatmap(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 2 def isneighbors(x): if not isinstance(x, torch.Tensor): return False return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) def exists(x): return x is not None def expand_dims_like(x, y): while x.dim() != y.dim(): x = x.unsqueeze(-1) return x def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. 
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == "__is_first_stage__": return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False, invalidate_cache=True): module, cls = string.rsplit(".", 1) if invalidate_cache: importlib.invalidate_caches() if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim if dims_to_append < 0: raise ValueError( f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" ) return x[(...,) + (None,) * dims_to_append] def load_model_from_config(config, ckpt, verbose=True, freeze=True): print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) elif ckpt.endswith("bin"): sd = torch.load(ckpt, map_location="cpu") else: raise NotImplementedError model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: print("missing keys:") print(m) if len(u) > 0 and verbose: print("unexpected keys:") print(u) if freeze: for param in model.parameters(): param.requires_grad = False model.eval() return model def 
format_number(num): num = float(num) num /= 1000.0 return '{:.0f}{}'.format(num, 'k') def get_num_params(model: torch.nn.ModuleList) -> int: num_params = sum(p.numel() for p in model.parameters()) return num_params def get_num_flop_per_token(num_params, model_config, seq_len) -> int: l, h, q, t = ( model_config.n_layers, model_config.n_heads, model_config.dim // model_config.n_heads, seq_len, ) # Reasoning behind the factor of 12 for the self-attention part of the formula: # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) # 2. the flash attention does 1 more matmul recomputation in the backward # but recomputation should not be counted in calculating MFU (+0) # 3. each matmul performs 1 multiplication and 1 addition (*2) # 4. we follow the convention and do not account for sparsity in causal attention flop_per_token = 6 * num_params + 12 * l * h * q * t return flop_per_token def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int: l, h, q = ( model_config.n_layers, model_config.n_heads, model_config.dim // model_config.n_heads, ) # 1. 每个自注意力层有2个矩阵乘法在前向传播,4个在反向传播 (6) # 2. 每个矩阵乘法执行1次乘法和1次加法 (*2) # 3. 双向注意力需要考虑所有token对,所以是t^2而不是t flop_per_sequence = 6 * num_params + 12 * l * h * q * seq_len * seq_len return flop_per_sequence # hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU def get_peak_flops(device_name: str) -> int: if "A100" in device_name: # data from https://www.nvidia.com/en-us/data-center/a100/ return 312e12 elif "H100" in device_name: # data from https://www.nvidia.com/en-us/data-center/h100/ # NOTE: Specifications are one-half lower without sparsity. 
if "NVL" in device_name: return 1979e12 elif "PCIe" in device_name: return 756e12 else: # for SXM and other variants return 989e12 else: # for other GPU types, assume A100 return 312e12 @dataclass(frozen=True) class Color: black = "\033[30m" red = "\033[31m" green = "\033[32m" yellow = "\033[33m" blue = "\033[34m" magenta = "\033[35m" cyan = "\033[36m" white = "\033[37m" reset = "\033[39m" @dataclass(frozen=True) class NoColor: black = "" red = "" green = "" yellow = "" blue = "" magenta = "" cyan = "" white = "" reset = "" ================================================ FILE: nit/utils/video_utils.py ================================================ import os import cv2 import numpy as np from PIL import Image def save_video_as_mp4(video_array, fps, output_path): # video_array: TCHW (RGB) height, width = video_array.shape[2], video_array.shape[3] fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) for t in range(video_array.shape[0]): frame = video_array[t].transpose(1, 2, 0) frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # RGB->BGR out.write(cv2.convertScaleAbs(frame)) out.release() def save_video_as_png(video_array, output_path): os.makedirs(output_path, exist_ok=True) # video_array: TCHW (RGB) for i, sample in enumerate(video_array): sample = np.transpose(sample, (1, 2, 0)) Image.fromarray(sample).save( # HWC os.path.join(output_path, f"{i:06d}.png") ) ================================================ FILE: nit/utils/warp_pos_idx.py ================================================ import torch import random from typing import Optional, Union def warp_pos_idx_from_grid( grid: torch.Tensor, shift: Optional[int] = 0, scale: Optional[str] = None, max_len: Optional[Union[int, float]]=None ): ''' grid: the 2-D positional index to be warped (B, 2, D) shift: the max shift value for the positional indices scale: the scale scheme for warping positional indices max_len: the max scale length ''' grid[:, 0] = 
warp_pos_idx(grid[:, 0], shift, scale, max_len) grid[:, 1] = warp_pos_idx(grid[:, 1], shift, scale, max_len) return grid def warp_pos_idx( pos_idx: torch.Tensor, shift: Optional[int] = 0, scale: Optional[str] = None, max_len: Optional[Union[int, float]]=None ): ''' pos_idx: the 1-D positional index to be warped (B, D) shift: the max shift value for the positional indices scale: the scale scheme for warping positional indices max_len: the max scale length ''' if scale != None: assert isinstance(scale, str) and isinstance(max_len, (int, float)) if scale.lower() == 'linear': pos_idx = max_len * (pos_idx / pos_idx.max()) elif scale.lower() == 'sqrt': pos_idx = max_len * torch.sqrt(pos_idx / max_len) elif scale.lower() in ['sine', 'cosine', 'sin', 'cos']: pos_idx = max_len * torch.sin(pos_idx / max_len * (torch.pi/2)) else: raise NotImplementedError('Only support linear, cosine, beta scale scheme for warping') pos_idx = pos_idx + random.randint(0, shift) return pos_idx ================================================ FILE: projects/evaluate/adm_evaluator.py ================================================ import argparse import io import os import random import warnings import zipfile from abc import ABC, abstractmethod from contextlib import contextmanager from functools import partial from multiprocessing import cpu_count from multiprocessing.pool import ThreadPool from typing import Iterable, Optional, Tuple from PIL import Image import numpy as np import requests import tensorflow.compat.v1 as tf from scipy import linalg from tqdm.auto import tqdm INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" INCEPTION_V3_PATH = "checkpoints/classify_image_graph_def.pb" FID_POOL_NAME = "pool_3:0" FID_SPATIAL_NAME = "mixed_6/conv:0" def main(): parser = argparse.ArgumentParser() parser.add_argument("ref_batch", help="path to reference batch npz file") parser.add_argument("sample_batch", help="path to sample 
batch npz file") args = parser.parse_args() config = tf.ConfigProto( allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph ) config.gpu_options.allow_growth = True evaluator = Evaluator(tf.Session(config=config)) print("warming up TensorFlow...") # This will cause TF to print a bunch of verbose stuff now rather # than after the next print(), to help prevent confusion. evaluator.warmup() print("computing reference batch activations...") ref_acts = evaluator.read_activations(args.ref_batch) print("computing/reading reference batch statistics...") ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) print("computing sample batch activations...") sample_acts = evaluator.read_activations(args.sample_batch) print("computing/reading sample batch statistics...") sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts) print("Computing evaluations...") print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) print("FID:", sample_stats.frechet_distance(ref_stats)) print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) print("Precision:", prec) print("Recall:", recall) class InvalidFIDException(Exception): pass class FIDStatistics: def __init__(self, mu: np.ndarray, sigma: np.ndarray): self.mu = mu self.sigma = sigma def frechet_distance(self, other, eps=1e-6): """ Compute the Frechet distance between two sets of statistics. 
""" # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 mu1, sigma1 = self.mu, self.sigma mu2, sigma2 = other.mu, other.sigma mu1 = np.atleast_1d(mu1) mu2 = np.atleast_1d(mu2) sigma1 = np.atleast_2d(sigma1) sigma2 = np.atleast_2d(sigma2) assert ( mu1.shape == mu2.shape ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" assert ( sigma1.shape == sigma2.shape ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" diff = mu1 - mu2 # product might be almost singular covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): msg = ( "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps ) warnings.warn(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) # numerical error might give slight imaginary component if np.iscomplexobj(covmean): if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): m = np.max(np.abs(covmean.imag)) raise ValueError("Imaginary component {}".format(m)) covmean = covmean.real tr_covmean = np.trace(covmean) return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean class Evaluator: def __init__( self, session, batch_size=64, softmax_batch_size=512, ): self.sess = session self.batch_size = batch_size self.softmax_batch_size = softmax_batch_size self.manifold_estimator = ManifoldEstimator(session) with self.sess.graph.as_default(): self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) self.softmax = _create_softmax_graph(self.softmax_input) def warmup(self): self.compute_activations(np.zeros([1, 8, 64, 64, 3])) def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: if npz_path.endswith('.npz'): with 
open_npz_array(npz_path, "arr_0") as reader: return self.compute_activations(reader.read_batches(self.batch_size)) else: preds = [] spatial_preds = [] files = os.listdir(npz_path) run_iter = int(len(files) / self.batch_size) for i in tqdm(range(run_iter)): samples = [] for file in files[i*self.batch_size: (i+1)*self.batch_size]: try: sample_pil = Image.open(os.path.join(npz_path, file)) sample_np = np.asarray(sample_pil).astype(np.uint8) samples.append(sample_np) except: print('wrong file', os.path.join(npz_path, file)) samples = np.stack(samples) samples = samples.astype(np.float32) pred, spatial_pred = self.sess.run( [self.pool_features, self.spatial_features], {self.image_input: samples} ) preds.append(pred.reshape([pred.shape[0], -1])) spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) return ( np.concatenate(preds, axis=0), np.concatenate(spatial_preds, axis=0), ) def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: """ Compute image features for downstream evals. :param batches: a iterator over NHWC numpy arrays in [0, 255]. :return: a tuple of numpy arrays of shape [N x X], where X is a feature dimension. The tuple is (pool_3, spatial). 
""" preds = [] spatial_preds = [] for batch in tqdm(batches): batch = batch.astype(np.float32) pred, spatial_pred = self.sess.run( [self.pool_features, self.spatial_features], {self.image_input: batch} ) preds.append(pred.reshape([pred.shape[0], -1])) spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) return ( np.concatenate(preds, axis=0), np.concatenate(spatial_preds, axis=0), ) def read_statistics( self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] ) -> Tuple[FIDStatistics, FIDStatistics]: if npz_path.endswith('.npz'): obj = np.load(npz_path) if "mu" in list(obj.keys()): return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( obj["mu_s"], obj["sigma_s"] ) return tuple(self.compute_statistics(x) for x in activations) def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: mu = np.mean(activations, axis=0) sigma = np.cov(activations, rowvar=False) return FIDStatistics(mu, sigma) def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: softmax_out = [] for i in range(0, len(activations), self.softmax_batch_size): acts = activations[i : i + self.softmax_batch_size] softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})) preds = np.concatenate(softmax_out, axis=0) # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 scores = [] for i in range(0, len(preds), split_size): part = preds[i : i + split_size] kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) kl = np.mean(np.sum(kl, 1)) scores.append(np.exp(kl)) return float(np.mean(scores)) def compute_prec_recall( self, activations_ref: np.ndarray, activations_sample: np.ndarray ) -> Tuple[float, float]: radii_1 = self.manifold_estimator.manifold_radii(activations_ref) radii_2 = self.manifold_estimator.manifold_radii(activations_sample) pr = self.manifold_estimator.evaluate_pr( activations_ref, radii_1, 
activations_sample, radii_2 ) return (float(pr[0][0]), float(pr[1][0])) class ManifoldEstimator: """ A helper for comparing manifolds of feature vectors. Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 """ def __init__( self, session, row_batch_size=10000, col_batch_size=10000, nhood_sizes=(3,), clamp_to_percentile=None, eps=1e-5, ): """ Estimate the manifold of given feature vectors. :param session: the TensorFlow session. :param row_batch_size: row batch size to compute pairwise distances (parameter to trade-off between memory usage and performance). :param col_batch_size: column batch size to compute pairwise distances. :param nhood_sizes: number of neighbors used to estimate the manifold. :param clamp_to_percentile: prune hyperspheres that have radius larger than the given percentile. :param eps: small number for numerical stability. """ self.distance_block = DistanceBlock(session) self.row_batch_size = row_batch_size self.col_batch_size = col_batch_size self.nhood_sizes = nhood_sizes self.num_nhoods = len(nhood_sizes) self.clamp_to_percentile = clamp_to_percentile self.eps = eps def warmup(self): feats, radii = ( np.zeros([1, 2048], dtype=np.float32), np.zeros([1, 1], dtype=np.float32), ) self.evaluate_pr(feats, radii, feats, radii) def manifold_radii(self, features: np.ndarray) -> np.ndarray: num_images = len(features) # Estimate manifold of features by calculating distances to k-NN of each sample. 
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) for begin1 in range(0, num_images, self.row_batch_size): end1 = min(begin1 + self.row_batch_size, num_images) row_batch = features[begin1:end1] for begin2 in range(0, num_images, self.col_batch_size): end2 = min(begin2 + self.col_batch_size, num_images) col_batch = features[begin2:end2] # Compute distances between batches. distance_batch[ 0 : end1 - begin1, begin2:end2 ] = self.distance_block.pairwise_distances(row_batch, col_batch) # Find the k-nearest neighbor from the current batch. radii[begin1:end1, :] = np.concatenate( [ x[:, self.nhood_sizes] for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1) ], axis=0, ) if self.clamp_to_percentile is not None: max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) radii[radii > max_distances] = 0 return radii def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray): """ Evaluate if new feature vectors are at the manifold. 
""" num_eval_images = eval_features.shape[0] num_ref_images = radii.shape[0] distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) max_realism_score = np.zeros([num_eval_images], dtype=np.float32) nearest_indices = np.zeros([num_eval_images], dtype=np.int32) for begin1 in range(0, num_eval_images, self.row_batch_size): end1 = min(begin1 + self.row_batch_size, num_eval_images) feature_batch = eval_features[begin1:end1] for begin2 in range(0, num_ref_images, self.col_batch_size): end2 = min(begin2 + self.col_batch_size, num_ref_images) ref_batch = features[begin2:end2] distance_batch[ 0 : end1 - begin1, begin2:end2 ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) # From the minibatch of new feature vectors, determine if they are in the estimated manifold. # If a feature vector is inside a hypersphere of some reference sample, then # the new sample lies at the estimated manifold. # The radii of the hyperspheres are determined from distances of neighborhood size k. samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) max_realism_score[begin1:end1] = np.max( radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 ) nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1) return { "fraction": float(np.mean(batch_predictions)), "batch_predictions": batch_predictions, "max_realisim_score": max_realism_score, "nearest_indices": nearest_indices, } def evaluate_pr( self, features_1: np.ndarray, radii_1: np.ndarray, features_2: np.ndarray, radii_2: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: """ Evaluate precision and recall efficiently. :param features_1: [N1 x D] feature vectors for reference batch. :param radii_1: [N1 x K1] radii for reference vectors. 
:param features_2: [N2 x D] feature vectors for the other batch. :param radii_2: [N x K2] radii for other vectors. :return: a tuple of arrays for (precision, recall): - precision: an np.ndarray of length K1 - recall: an np.ndarray of length K2 """ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool) features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool) for begin_1 in range(0, len(features_1), self.row_batch_size): end_1 = begin_1 + self.row_batch_size batch_1 = features_1[begin_1:end_1] for begin_2 in range(0, len(features_2), self.col_batch_size): end_2 = begin_2 + self.col_batch_size batch_2 = features_2[begin_2:end_2] batch_1_in, batch_2_in = self.distance_block.less_thans( batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] ) features_1_status[begin_1:end_1] |= batch_1_in features_2_status[begin_2:end_2] |= batch_2_in return ( np.mean(features_2_status.astype(np.float64), axis=0), np.mean(features_1_status.astype(np.float64), axis=0), ) class DistanceBlock: """ Calculate pairwise distances between vectors. Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 """ def __init__(self, session): self.session = session # Initialize TF graph to calculate pairwise distances. with session.graph.as_default(): self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) distance_block_16 = _batch_pairwise_distances( tf.cast(self._features_batch1, tf.float16), tf.cast(self._features_batch2, tf.float16), ) self.distance_block = tf.cond( tf.reduce_all(tf.math.is_finite(distance_block_16)), lambda: tf.cast(distance_block_16, tf.float32), lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2), ) # Extra logic for less thans. 
self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) dist32 = tf.cast(self.distance_block, tf.float32)[..., None] self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) def pairwise_distances(self, U, V): """ Evaluate pairwise distances between two batches of feature vectors. """ return self.session.run( self.distance_block, feed_dict={self._features_batch1: U, self._features_batch2: V}, ) def less_thans(self, batch_1, radii_1, batch_2, radii_2): return self.session.run( [self._batch_1_in, self._batch_2_in], feed_dict={ self._features_batch1: batch_1, self._features_batch2: batch_2, self._radii1: radii_1, self._radii2: radii_2, }, ) def _batch_pairwise_distances(U, V): """ Compute pairwise distances between two batches of feature vectors. """ with tf.variable_scope("pairwise_dist_block"): # Squared norms of each row in U and V. norm_u = tf.reduce_sum(tf.square(U), 1) norm_v = tf.reduce_sum(tf.square(V), 1) # norm_u as a column and norm_v as a row vectors. norm_u = tf.reshape(norm_u, [-1, 1]) norm_v = tf.reshape(norm_v, [1, -1]) # Pairwise squared Euclidean distances. 
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) return D class NpzArrayReader(ABC): @abstractmethod def read_batch(self, batch_size: int) -> Optional[np.ndarray]: pass @abstractmethod def remaining(self) -> int: pass def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: def gen_fn(): while True: batch = self.read_batch(batch_size) if batch is None: break yield batch rem = self.remaining() num_batches = rem // batch_size + int(rem % batch_size != 0) return BatchIterator(gen_fn, num_batches) class BatchIterator: def __init__(self, gen_fn, length): self.gen_fn = gen_fn self.length = length def __len__(self): return self.length def __iter__(self): return self.gen_fn() class StreamingNpzArrayReader(NpzArrayReader): def __init__(self, arr_f, shape, dtype): self.arr_f = arr_f self.shape = shape self.dtype = dtype self.idx = 0 def read_batch(self, batch_size: int) -> Optional[np.ndarray]: if self.idx >= self.shape[0]: return None bs = min(batch_size, self.shape[0] - self.idx) self.idx += bs if self.dtype.itemsize == 0: return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) read_count = bs * np.prod(self.shape[1:]) read_size = int(read_count * self.dtype.itemsize) data = _read_bytes(self.arr_f, read_size, "array data") return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) def remaining(self) -> int: return max(0, self.shape[0] - self.idx) class MemoryNpzArrayReader(NpzArrayReader): def __init__(self, arr): self.arr = arr self.idx = 0 @classmethod def load(cls, path: str, arr_name: str): with open(path, "rb") as f: arr = np.load(f)[arr_name] return cls(arr) def read_batch(self, batch_size: int) -> Optional[np.ndarray]: if self.idx >= self.arr.shape[0]: return None res = self.arr[self.idx : self.idx + batch_size] self.idx += batch_size return res def remaining(self) -> int: return max(0, self.arr.shape[0] - self.idx) @contextmanager def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: with 
_open_npy_file(path, arr_name) as arr_f: version = np.lib.format.read_magic(arr_f) if version == (1, 0): header = np.lib.format.read_array_header_1_0(arr_f) elif version == (2, 0): header = np.lib.format.read_array_header_2_0(arr_f) else: yield MemoryNpzArrayReader.load(path, arr_name) return shape, fortran, dtype = header if fortran or dtype.hasobject: yield MemoryNpzArrayReader.load(path, arr_name) else: yield StreamingNpzArrayReader(arr_f, shape, dtype) def _read_bytes(fp, size, error_template="ran out of data"): """ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 Read from file-like object until size bytes are read. Raises ValueError if not EOF is encountered before size bytes are read. Non-blocking objects only supported if they derive from io objects. Required as e.g. ZipExtFile in python 2.6 can return less data than requested. """ data = bytes() while True: # io files (default in python3) return None or raise on # would-block, python2 file will truncate, probably nothing can be # done about that. 
note that regular files can't be non-blocking try: r = fp.read(size - len(data)) data += r if len(r) == 0 or len(data) == size: break except io.BlockingIOError: pass if len(data) != size: msg = "EOF: reading %s, expected %d bytes got %d" raise ValueError(msg % (error_template, size, len(data))) else: return data @contextmanager def _open_npy_file(path: str, arr_name: str): with open(path, "rb") as f: with zipfile.ZipFile(f, "r") as zip_f: if f"{arr_name}.npy" not in zip_f.namelist(): raise ValueError(f"missing {arr_name} in npz file") with zip_f.open(f"{arr_name}.npy", "r") as arr_f: yield arr_f def _download_inception_model(): if os.path.exists(INCEPTION_V3_PATH): return print("downloading InceptionV3 model...") with requests.get(INCEPTION_V3_URL, stream=True) as r: r.raise_for_status() tmp_path = INCEPTION_V3_PATH + ".tmp" with open(tmp_path, "wb") as f: for chunk in tqdm(r.iter_content(chunk_size=8192)): f.write(chunk) os.rename(tmp_path, INCEPTION_V3_PATH) def _create_feature_graph(input_batch): _download_inception_model() prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" with open(INCEPTION_V3_PATH, "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) pool3, spatial = tf.import_graph_def( graph_def, input_map={f"ExpandDims:0": input_batch}, return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], name=prefix, ) _update_shapes(pool3) spatial = spatial[..., :7] return pool3, spatial def _create_softmax_graph(input_batch): _download_inception_model() prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" with open(INCEPTION_V3_PATH, "rb") as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) (matmul,) = tf.import_graph_def( graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix ) w = matmul.inputs[1] logits = tf.matmul(input_batch, w) return tf.nn.softmax(logits) def _update_shapes(pool3): # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 ops = 
pool3.graph.get_operations() for op in ops: for o in op.outputs: shape = o.get_shape() if shape._dims is not None: # pylint: disable=protected-access # shape = [s.value for s in shape] TF 1.x shape = [s for s in shape] # TF 2.x new_shape = [] for j, s in enumerate(shape): if s == 1 and j == 0: new_shape.append(None) else: new_shape.append(s) o.__dict__["_shape_val"] = tf.TensorShape(new_shape) return pool3 def _numpy_partition(arr, kth, **kwargs): num_workers = min(cpu_count(), len(arr)) chunk_size = len(arr) // num_workers extra = len(arr) % num_workers start_idx = 0 batches = [] for i in range(num_workers): size = chunk_size + (1 if i < extra else 0) batches.append(arr[start_idx : start_idx + size]) start_idx += size with ThreadPool(num_workers) as pool: return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) if __name__ == "__main__": main() ================================================ FILE: projects/preprocess/image_latent_c2i.py ================================================ import os import torch import argparse import datetime import time import torchvision import logging import math import shutil import accelerate import torch import torch.utils.checkpoint import diffusers import numpy as np import torch.nn.functional as F import einops import json import os.path as osp import functools from PIL import Image from torch.cuda import amp from torch.utils.data import DataLoader, Dataset from omegaconf import OmegaConf from accelerate import Accelerator, skip_first_batches from accelerate.logging import get_logger from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from tqdm.auto import tqdm from diffusers import AutoencoderKL, AutoencoderDC from nit.utils.misc_utils import instantiate_from_config from torchvision import transforms from torchvision.datasets.folder import DatasetFolder, default_loader from torchvision.transforms.functional import hflip from safetensors.torch import 
save_file from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from nit.utils.model_utils import dc_ae_encode, sd_vae_encode logger = get_logger(__name__, log_level="INFO") # For Omegaconf Tuple def resolve_tuple(*args): return tuple(args) OmegaConf.register_new_resolver("tuple", resolve_tuple) def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") # ----General Training Arguments---- parser.add_argument( "--config", type=str, default="", help="The config file for training.", ) parser.add_argument( "--project_dir", type=str, default="t2i_linear_attention", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) args = parser.parse_args() return args def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") class ImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way by default: :: root/dog/xxx.png root/dog/xxy.png root/dog/[...]/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/[...]/asd932_.png This class inherits from :class:`~torchvision.datasets.DatasetFolder` so the same 
methods can be overridden to customize the dataset. Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). imgs (list): List of (image path, class_index) tuples """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = default_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ): super().__init__( root, loader, IMG_EXTENSIONS if is_valid_file is None else None, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file, ) self.imgs = self.samples def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. 
""" path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target, path class ImagenetDataDictWrapper(Dataset): def __init__(self, dataset): super().__init__() self.dataset = dataset def __getitem__(self, i): x, y, p = self.dataset[i] return {"jpg": x, "cls": y, "path": p} def __len__(self): return len(self.dataset) # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60 def get_train_sampler(global_batch_size, max_steps, resume_step): sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long) return sample_indices[resume_step * global_batch_size : ].tolist() class ImagenetLoader(): def __init__(self, data_config): super().__init__() self.batch_size = data_config.dataloader.batch_size self.num_workers = data_config.dataloader.num_workers transform = transforms.Compose([ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, data_config.dataset.resolution)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, transform=transform)) self.test_dataset = None self.val_dataset = None def train_len(self): return len(self.train_dataset) def train_dataloader(self, global_batch_size, max_steps, resume_step): sampler = get_train_sampler( global_batch_size, max_steps, resume_step ) return DataLoader( self.train_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers, pin_memory=True, drop_last=True ) def test_dataloader(self): return None def val_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, pin_memory=True, drop_last=False ) def main(args): project_dir = args.project_dir config = 
OmegaConf.load(args.config) model_config = config.model data_config = config.data train_config = config.training config_dir = osp.join(project_dir, 'configs') checkpoint_dir = osp.join(project_dir, 'checkpoints') logging_dir = osp.join(project_dir, 'logs') sample_dir = osp.join(project_dir, 'samples') accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=train_config.gradient_accumulation_steps, mixed_precision=train_config.mixed_precision, log_with=train_config.tracker, project_config=accelerator_project_config, split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes ) # Handle the repository creation if accelerator.is_main_process: os.makedirs(project_dir, exist_ok=True) os.makedirs(config_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(logging_dir, exist_ok=True) os.makedirs(sample_dir, exist_ok=True) OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml")) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: diffusers.utils.logging.set_verbosity_info() else: diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
if args.seed is not None: set_seed(args.seed) if train_config.allow_tf32: # for A100 torch.backends.cuda.matmul.allow_tf32 = True # Setup models weight_dtype = torch.float32 if 'sd-vae' in model_config.vae: sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype) sd_vae.eval() sd_vae.requires_grad_(False) encode_func = functools.partial(sd_vae_encode, sd_vae) elif 'dc-ae' in model_config.vae: dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype) dc_ae.eval() dc_ae.requires_grad_(False) encode_func = functools.partial(dc_ae_encode, dc_ae) # Setup Dataloader total_batch_size = ( data_config.dataloader.batch_size * accelerator.num_processes * train_config.gradient_accumulation_steps ) global_steps = 0 if train_config.resume_from_checkpoint: # normal read with safety check if train_config.resume_from_checkpoint != "latest": resume_from_path = os.path.basename(train_config.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(checkpoint_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None if resume_from_path is None: logger.info( f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run." 
) train_config.resume_from_checkpoint = None else: global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps logger.info(f"Resuming from steps: {global_steps}") get_train_dataloader = ImagenetLoader(data_config) train_len = get_train_dataloader.train_len() train_config.max_train_steps = math.ceil(train_len / total_batch_size) train_dataloader = get_train_dataloader.train_dataloader( global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps, ) # Prepare Accelerate train_dataloader= accelerator.prepare(train_dataloader) # We need to recalculate our total training steps as the size of the training dataloader may have changed. logger.info("***** Running training *****") logger.info(f" Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}") logger.info(f" Dataset Length = {get_train_dataloader.train_len()}") logger.info(f" Instantaneous batch size per device = {data_config.dataloader.batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {train_config.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {train_config.max_train_steps}") # Potentially load in the weights and states from a previous save if train_config.resume_from_checkpoint and resume_from_path != None: accelerator.print(f"Resuming from checkpoint {resume_from_path}") accelerator.load_state(resume_from_path) # Only show the progress bar once on each machine. 
progress_bar = tqdm( range(0, train_config.max_train_steps), disable = not accelerator.is_local_main_process ) progress_bar.set_description("Optim Steps") progress_bar.update(global_steps) # prepare patch size and max sequence length # make directory os.makedirs(data_config.dataset.target_dir, exist_ok=True) def save_data(z, y, p): # p: 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG' # target_folder: 'target_dir/n01440764' target_folder = os.path.join(data_config.dataset.target_dir, p.split('/')[-2]) f_name = p.split('/')[-1].split('.')[0] os.makedirs(target_folder, exist_ok=True) single_data = dict(latent=z.contiguous(), label=y) save_file(single_data, os.path.join(target_folder, f'{f_name}.safetensors')) for step, batch in enumerate(train_dataloader, start=global_steps): for batch_key in batch.keys(): if not isinstance(batch[batch_key], (list, str)): batch[batch_key] = batch[batch_key].to(dtype=weight_dtype) x = batch['jpg'] y = batch['cls'] p = batch['path'] if 'sd-vae' in model_config.vae: z_ori = encode_func(x) z_hflip = encode_func(hflip(x)) z = torch.stack([z_ori, z_hflip], dim=1) elif 'dc-ae' in model_config.vae: z_ori = encode_func(x) z_hflip = encode_func(hflip(x)) z = torch.stack([z_ori, z_hflip], dim=1) for i in range(len(p)): save_data(z[i], y[i], p[i]) # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation if accelerator.sync_gradients: progress_bar.update(1) global_steps += 1 if global_steps % train_config.checkpointing_steps == 0: if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if train_config.checkpoints_total_limit is not None: checkpoints = os.listdir(checkpoint_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` 
checkpoints if len(checkpoints) >= train_config.checkpoints_total_limit: num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_steps}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") accelerator.wait_for_everyone() if global_steps >= train_config.max_train_steps: break # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args) ================================================ FILE: projects/preprocess/image_nr_latent_c2i.py ================================================ import os import torch import argparse import datetime import time import torchvision import logging import math import shutil import accelerate import torch import torch.utils.checkpoint import diffusers import numpy as np import torch.nn.functional as F import einops import json import os.path as osp import functools from PIL import Image from torch.cuda import amp from torch.utils.data import DataLoader, Dataset from omegaconf import OmegaConf from accelerate import Accelerator, skip_first_batches from accelerate.logging import get_logger from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from tqdm.auto import tqdm from diffusers import AutoencoderKL, AutoencoderDC from nit.utils.misc_utils import instantiate_from_config from torchvision import transforms from torchvision.datasets.folder import DatasetFolder, 
default_loader from torchvision.transforms.functional import hflip from safetensors.torch import save_file from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from nit.utils.model_utils import dc_ae_encode, sd_vae_encode logger = get_logger(__name__, log_level="INFO") # For Omegaconf Tuple def resolve_tuple(*args): return tuple(args) OmegaConf.register_new_resolver("tuple", resolve_tuple) def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") # ----General Training Arguments---- parser.add_argument( "--config", type=str, default="", help="The config file for training.", ) parser.add_argument( "--project_dir", type=str, default="t2i_linear_attention", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) args = parser.parse_args() return args def native_resolution_resize(pil_image, min_image_size, max_image_size): """ Center cropping implementation from ADM. 
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ w, h = pil_image.size if w * h < max_image_size**2: new_w = max(1, int(w/min_image_size)) * min_image_size new_h = max(1, int(h/min_image_size)) * min_image_size else: new_w = np.sqrt(w/h) * max_image_size new_h = new_w * h / w new_w = int(new_w/min_image_size) * min_image_size new_h = int(new_h/min_image_size) * min_image_size pil_image = pil_image.resize((new_w, new_h), resample=Image.Resampling.BICUBIC) return pil_image IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") class ImageFolder(DatasetFolder): """A generic data loader where the images are arranged in this way by default: :: root/dog/xxx.png root/dog/xxy.png root/dog/[...]/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/[...]/asd932_.png This class inherits from :class:`~torchvision.datasets.DatasetFolder` so the same methods can be overridden to customize the dataset. Args: root (string): Root directory path. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). 
imgs (list): List of (image path, class_index) tuples """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = default_loader, is_valid_file: Optional[Callable[[str], bool]] = None, ): super().__init__( root, loader, IMG_EXTENSIONS if is_valid_file is None else None, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file, ) self.imgs = self.samples def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (sample, target) where target is class_index of the target class. """ path, target = self.samples[index] sample = self.loader(path) if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target, path class ImagenetDataDictWrapper(Dataset): def __init__(self, dataset): super().__init__() self.dataset = dataset def __getitem__(self, i): x, y, p = self.dataset[i] return {"jpg": x, "cls": y, "path": p} def __len__(self): return len(self.dataset) # from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60 def get_train_sampler(global_batch_size, max_steps, resume_step): sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long) return sample_indices[resume_step * global_batch_size : ].tolist() class ImagenetLoader(): def __init__(self, data_config): super().__init__() self.batch_size = data_config.dataloader.batch_size self.num_workers = data_config.dataloader.num_workers transform = transforms.Compose([ transforms.Lambda(lambda pil_image: native_resolution_resize( pil_image, data_config.dataset.min_image_size, data_config.dataset.max_image_size )), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, 
transform=transform)) self.test_dataset = None self.val_dataset = None def train_len(self): return len(self.train_dataset) def train_dataloader(self, global_batch_size, max_steps, resume_step): sampler = get_train_sampler( global_batch_size, max_steps, resume_step ) return DataLoader( self.train_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=self.num_workers, pin_memory=True, drop_last=False ) def test_dataloader(self): return None def val_dataloader(self): return DataLoader( self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle, num_workers=self.num_workers, pin_memory=True, drop_last=True ) def main(args): project_dir = args.project_dir config = OmegaConf.load(args.config) model_config = config.model data_config = config.data train_config = config.training config_dir = osp.join(project_dir, 'configs') checkpoint_dir = osp.join(project_dir, 'checkpoints') logging_dir = osp.join(project_dir, 'logs') sample_dir = osp.join(project_dir, 'samples') accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=train_config.gradient_accumulation_steps, mixed_precision=train_config.mixed_precision, log_with=train_config.tracker, project_config=accelerator_project_config, split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes ) # Handle the repository creation if accelerator.is_main_process: os.makedirs(project_dir, exist_ok=True) os.makedirs(config_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(logging_dir, exist_ok=True) os.makedirs(sample_dir, exist_ok=True) OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml")) # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: diffusers.utils.logging.set_verbosity_info() else: diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) if train_config.allow_tf32: # for A100 torch.backends.cuda.matmul.allow_tf32 = True # Setup models weight_dtype = torch.float32 if 'sd-vae' in model_config.vae: sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype) sd_vae.eval() sd_vae.requires_grad_(False) encode_func = functools.partial(sd_vae_encode, sd_vae) elif 'dc-ae' in model_config.vae: dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype) dc_ae.eval() dc_ae.requires_grad_(False) encode_func = functools.partial(dc_ae_encode, dc_ae) # Setup Dataloader total_batch_size = ( data_config.dataloader.batch_size * accelerator.num_processes * train_config.gradient_accumulation_steps ) global_steps = 0 if train_config.resume_from_checkpoint: # normal read with safety check if train_config.resume_from_checkpoint != "latest": resume_from_path = os.path.basename(train_config.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(checkpoint_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None if resume_from_path is None: logger.info( f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run." 
) train_config.resume_from_checkpoint = None else: global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps logger.info(f"Resuming from steps: {global_steps}") get_train_dataloader = ImagenetLoader(data_config) train_len = get_train_dataloader.train_len() train_config.max_train_steps = math.ceil(train_len / total_batch_size) train_dataloader = get_train_dataloader.train_dataloader( global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps, ) # Prepare Accelerate train_dataloader= accelerator.prepare(train_dataloader) # We need to recalculate our total training steps as the size of the training dataloader may have changed. logger.info("***** Running training *****") logger.info(f" Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}") logger.info(f" Dataset Length = {get_train_dataloader.train_len()}") logger.info(f" Instantaneous batch size per device = {data_config.dataloader.batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {train_config.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {train_config.max_train_steps}") # Potentially load in the weights and states from a previous save if train_config.resume_from_checkpoint and resume_from_path != None: accelerator.print(f"Resuming from checkpoint {resume_from_path}") accelerator.load_state(resume_from_path) # Only show the progress bar once on each machine. 
progress_bar = tqdm( range(0, train_config.max_train_steps), disable = not accelerator.is_local_main_process ) progress_bar.set_description("Optim Steps") progress_bar.update(global_steps) # prepare patch size and max sequence length # make directory os.makedirs(data_config.dataset.target_dir, exist_ok=True) def save_data(z, y, p): # p: 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG' # target_folder: 'target_dir/n01440764' target_folder = os.path.join(data_config.dataset.target_dir, p.split('/')[-2]) f_name = p.split('/')[-1].split('.')[0] os.makedirs(target_folder, exist_ok=True) single_data = dict(latent=z.contiguous(), label=y) save_file(single_data, os.path.join(target_folder, f'{f_name}.safetensors')) for step, batch in enumerate(train_dataloader, start=global_steps): for batch_key in batch.keys(): if not isinstance(batch[batch_key], (list, str)): batch[batch_key] = batch[batch_key].to(dtype=weight_dtype) x = batch['jpg'] y = batch['cls'] p = batch['path'] if 'sd-vae' in model_config.vae: z_ori = encode_func(x) z_hflip = encode_func(hflip(x)) z = torch.stack([z_ori, z_hflip], dim=1) elif 'dc-ae' in model_config.vae: z_ori = encode_func(x) z_hflip = encode_func(hflip(x)) z = torch.stack([z_ori, z_hflip], dim=1) for i in range(len(p)): save_data(z[i], y[i], p[i]) # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation if accelerator.sync_gradients: progress_bar.update(1) global_steps += 1 if global_steps % train_config.checkpointing_steps == 0: if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if train_config.checkpoints_total_limit is not None: checkpoints = os.listdir(checkpoint_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` 
checkpoints if len(checkpoints) >= train_config.checkpoints_total_limit: num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(checkpoint_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(checkpoint_dir, f"checkpoint-{global_steps}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") accelerator.wait_for_everyone() if global_steps >= train_config.max_train_steps: break # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args) ================================================ FILE: projects/sample/sample_c2i_ddp.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ Samples a large number of images from a pre-trained SiT model using DDP. Subsequently saves a .npz file that can be used to compute FID and other evaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations For a simple single-GPU/CPU sampling script, see sample.py. 
""" import torch import torch.distributed as dist from diffusers.models import AutoencoderKL, AutoencoderDC from tqdm import tqdm import os from PIL import Image import numpy as np import math import functools import argparse from omegaconf import OmegaConf from einops import rearrange from nit.schedulers.flow_matching.samplers_c2i import euler_sampler, euler_maruyama_sampler from nit.utils import init_from_ckpt from nit.utils.misc_utils import instantiate_from_config from nit.utils.model_utils import sd_vae_decode, dc_ae_decode def create_npz_from_sample_folder(sample_dir, num=50_000): """ Builds a single .npz file from a folder of .png samples. """ samples = [] for i in tqdm(range(num), desc="Building .npz file from samples"): sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") sample_np = np.asarray(sample_pil).astype(np.uint8) samples.append(sample_np) samples = np.stack(samples) assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) npz_path = f"{sample_dir}.npz" np.savez(npz_path, arr_0=samples) print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") return npz_path def main(args): """ Run sampling. """ torch.backends.cuda.matmul.allow_tf32 = args.tf32 # True: fast but may lead to some small numerical differences assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" torch.set_grad_enabled(False) # Setup DDP:cd dist.init_process_group("nccl") rank = dist.get_rank() device = rank % torch.cuda.device_count() seed = args.global_seed * dist.get_world_size() + rank torch.manual_seed(seed) torch.cuda.set_device(device) print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") # setup dtype dtype = torch.bfloat16 # Load model: config = OmegaConf.load(args.config) model_config = config.model if 'dc-ae' in model_config.vae_dir: dc_ae = AutoencoderDC.from_pretrained(model_config.vae_dir).to(device) if args.slice_vae: dc_ae.enable_slicing() if args.slice_vae: dc_ae.enable_slicing() spatial_downsample = 32 decode_func = functools.partial(dc_ae_decode, dc_ae) elif 'sd-vae' in model_config.vae_dir: sd_vae = AutoencoderKL.from_pretrained(model_config.vae_dir).to(device) if args.slice_vae: sd_vae.enable_slicing() if args.slice_vae: sd_vae.enable_slicing() spatial_downsample = 8 decode_func = functools.partial(sd_vae_decode, sd_vae) else: raise assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" # image resolution patch_size = int(model_config.network.params.patch_size) latent_h = int(args.height / spatial_downsample / patch_size) latent_w = int(args.width / spatial_downsample / patch_size) if args.interpolation != 'no': model_config.network.params['custom_freqs'] = args.interpolation model_config.network.params['max_pe_len_h'] = latent_h model_config.network.params['max_pe_len_w'] = latent_w model_config.network.params['decouple'] = args.decouple model_config.network.params['ori_max_pe_len'] = int(args.ori_max_pe_len) model = instantiate_from_config(model_config.network).to(device=device, dtype=dtype) init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True) model.eval() # important! 
if args.ag_config != None and args.ag_ckpt != None: ag_config = OmegaConf.load(args.ag_config) ag_model_config = ag_config.model ag_model = instantiate_from_config(ag_model_config.network).to(device=device, dtype=dtype) init_from_ckpt(ag_model, checkpoint_dir=args.ag_ckpt, ignore_keys=None, verbose=True) ag_model.eval() # important! else: ag_model = None # Create folder to save samples: train_iter = args.ckpt.split('/')[-2].split('-')[-1] folder_name = f"{train_iter}-{args.height}x{args.width}-{args.mode}-{args.num_steps}-" \ f"cfg-{args.cfg_scale}-low-{args.guidance_low}-high-{args.guidance_high}" if ag_model != None: sample_folder_dir = f"{args.sample_dir}/ag-{folder_name}" else: sample_folder_dir = f"{args.sample_dir}/{folder_name}" if args.interpolation != 'no': sample_folder_dir += f'-{args.interpolation}' if rank == 0: os.makedirs(sample_folder_dir, exist_ok=True) print(f"Saving .png samples at {sample_folder_dir}") dist.barrier() # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: n = args.per_proc_batch_size global_batch_size = n * dist.get_world_size() # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size) if rank == 0: print(f"Total number of images that will be sampled: {total_samples}") print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}") assert total_samples % dist.get_world_size() == 0, "total_samples must be divisible by world_size" samples_needed_this_gpu = int(total_samples // dist.get_world_size()) assert samples_needed_this_gpu % n == 0, "samples_needed_this_gpu must be divisible by the per-GPU batch size" iterations = int(samples_needed_this_gpu // n) pbar = range(iterations) pbar = tqdm(pbar) if rank == 0 else pbar total = 0 for i in pbar: # Sample inputs: z = torch.randn( (n*latent_h*latent_w, model.in_channels, patch_size, 
patch_size), device=device, dtype=dtype ) y = torch.randint(0, args.num_classes, (n,), device=device) hw_list = torch.tensor([[latent_h, latent_w] for _ in range(n)], device=device, dtype=torch.int) seqlens = hw_list[:, 0] * hw_list[:, 1] cu_seqlens = torch.cat([ torch.tensor([0], device=hw_list.device, dtype=torch.int32), torch.cumsum(seqlens, dim=0, dtype=torch.int32) ]) can_pass = True for j in range(n): index = j * dist.get_world_size() + rank + total if not os.path.exists(f"{sample_folder_dir}/{index:06d}.png"): can_pass = False if can_pass: total += global_batch_size print('total: ', total) continue # Sample images: sampling_kwargs = dict( model=model, ag_model=ag_model, latents=z, y=y, hw_list=hw_list, num_steps=args.num_steps, heun=args.heun, cfg_scale=args.cfg_scale, guidance_low=args.guidance_low, guidance_high=args.guidance_high, path_type=args.path_type, ) with torch.no_grad(): if args.mode == "sde": samples = euler_maruyama_sampler(**sampling_kwargs).to(torch.float32) elif args.mode == "ode": samples = euler_sampler(**sampling_kwargs).to(torch.float32) else: raise NotImplementedError samples = rearrange(samples, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=latent_h, w=latent_w) samples = decode_func(samples) samples = (samples + 1) / 2. samples = torch.clamp( 255. 
* samples, 0, 255 ).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() # Save samples to disk as individual .png files for i, sample in enumerate(samples): index = i * dist.get_world_size() + rank + total Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") total += global_batch_size # Make sure all processes have finished saving their samples before attempting to convert to .npz dist.barrier() if rank == 0: # create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) print("Done.") dist.barrier() dist.destroy_process_group() if __name__ == "__main__": parser = argparse.ArgumentParser() # seed parser.add_argument("--global-seed", type=int, default=0) # precision parser.add_argument("--tf32", action=argparse.BooleanOptionalAction, default=True, help="By default, use TF32 matmuls. This massively accelerates sampling on Ampere GPUs.") # logging/saving: parser.add_argument("--config", type=str, default=None, help="Optional config to a SiT checkpoint.") parser.add_argument("--ckpt", type=str, default=None, help="Optional path to a SiT checkpoint.") parser.add_argument("--sample-dir", type=str, default="workdir/c2i/samples") parser.add_argument("--ag-config", type=str, default=None) parser.add_argument("--ag-ckpt", type=str, default=None) # model parser.add_argument("--num-classes", type=int, default=1000) parser.add_argument("--height", type=int, default=256) parser.add_argument("--width", type=int, default=256) parser.add_argument("--slice_vae", action=argparse.BooleanOptionalAction, default=False) # only for ode # number of samples parser.add_argument("--per-proc-batch-size", type=int, default=32) parser.add_argument("--num-fid-samples", type=int, default=50_000) # sampling related hyperparameters parser.add_argument("--mode", type=str, default="ode") parser.add_argument("--cfg-scale", type=float, default=1.5) parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"]) 
parser.add_argument("--num-steps", type=int, default=50) parser.add_argument("--heun", action=argparse.BooleanOptionalAction, default=False) # only for ode parser.add_argument("--guidance-low", type=float, default=0.) parser.add_argument("--guidance-high", type=float, default=1.) parser.add_argument("--interpolation", type=str, choices=['no', 'linear', 'ntk-aware', 'ntk-by-parts', 'yarn', 'ntk-aware-pro1', 'ntk-aware-pro2', 'scale1', 'scale2'], default='no') # interpolation parser.add_argument("--ori-max-pe-len", default=None, type=int) parser.add_argument("--decouple", default=False, action="store_true") # interpolation # will be deprecated parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False) # only for ode args = parser.parse_args() main(args) ================================================ FILE: projects/train/packed_trainer_c2i.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and import argparse import copy import functools import gc import itertools import json import logging import math import os import time import random import shutil import importlib import csv import numpy as np import os.path as osp from pathlib import Path from typing import List, Union from packaging import version from tqdm.auto import tqdm from copy import deepcopy from omegaconf import OmegaConf from einops import rearrange, repeat import torch import torch.nn.functional as F import torch.utils.checkpoint import torchvision.transforms.functional as TF from torch.utils.data import default_collate, Dataset from torchvision import transforms from torchvision.transforms import Normalize import accelerate from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed import transformers import diffusers from diffusers.optimization import get_scheduler from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available from diffusers import AutoencoderKL from timeit import default_timer as timer from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from nit.schedulers.flow_matching.loss import FlowMatchingLoss from nit.data.packed_c2i_data import C2ILoader from nit.utils.misc_utils import ( get_obj_from_str, get_dtype, instantiate_from_config ) from nit.utils.train_utils import ( update_ema, log_validation, ) from nit.utils.gpu_memory_monitor import build_gpu_memory_monitor # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
check_min_version("0.18.0.dev0") logger = get_logger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") # ----General Training Arguments---- parser.add_argument( "--config", type=str, default="", help="The config file for training.", ) parser.add_argument( "--project_dir", type=str, default="t2i_linear_attention", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--seed", type=int, default=None, help="A seed for reproducible training." ) args = parser.parse_args() return args def main(args): project_dir = args.project_dir config = OmegaConf.load(args.config) model_config = config.model data_config = config.data train_config = config.training config_dir = osp.join(project_dir, 'configs') checkpoint_dir = osp.join(project_dir, 'checkpoints') logging_dir = osp.join(project_dir, 'logs') sample_dir = osp.join(project_dir, 'samples') if getattr(train_config, 'fsdp_config', None) != None: import functools from torch.distributed.fsdp.fully_sharded_data_parallel import ( BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision, StateDictType, FullStateDictConfig, FullOptimStateDictConfig, ) from accelerate.utils import FullyShardedDataParallelPlugin from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy fsdp_cfg = train_config.fsdp_config if train_config.mixed_precision == "fp16": dtype = torch.float16 elif train_config.mixed_precision == "bf16": dtype = torch.bfloat16 else: dtype = torch.float32 fsdp_plugin = FullyShardedDataParallelPlugin( sharding_strategy = { 'FULL_SHARD': ShardingStrategy.FULL_SHARD, 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP, 'NO_SHARD': ShardingStrategy.NO_SHARD, 'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD, 'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2, }[fsdp_cfg.sharding_strategy], backward_prefetch = { 'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE, 'BACKWARD_POST': 
BackwardPrefetch.BACKWARD_POST, }[fsdp_cfg.backward_prefetch], mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, ), auto_wrap_policy = functools.partial( size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params ), cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload), state_dict_type = { 'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT, 'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT, 'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT }[fsdp_cfg.state_dict_type], state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True), optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), limit_all_gathers = fsdp_cfg.limit_all_gathers, use_orig_params = fsdp_cfg.use_orig_params, sync_module_states = fsdp_cfg.sync_module_states, forward_prefetch = fsdp_cfg.forward_prefetch, activation_checkpointing = fsdp_cfg.activation_checkpointing, ) else: fsdp_plugin = None accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=train_config.gradient_accumulation_steps, mixed_precision=train_config.mixed_precision, log_with=train_config.tracker, project_config=accelerator_project_config, split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes fsdp_plugin=fsdp_plugin, ) # Handle the repository creation if accelerator.is_main_process: os.makedirs(project_dir, exist_ok=True) os.makedirs(config_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(logging_dir, exist_ok=True) os.makedirs(sample_dir, exist_ok=True) OmegaConf.save(config=config, f=osp.join(config_dir, "config.yaml")) # Make one log on every process with the configuration for debugging. 
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) if train_config.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True total_batch_size = ( data_config.dataloader.batch_size * accelerator.num_processes * train_config.gradient_accumulation_steps ) if train_config.scale_lr: learning_rate = ( train_config.learning_rate * total_batch_size / train_config.learning_rate_base_batch_size ) else: learning_rate = train_config.learning_rate # prepare model, dataloader, optimizer and scheduler model = instantiate_from_config(model_config.network).to(device=accelerator.device) model.train() if model_config.use_ema: ema_model = deepcopy(model) ema_model.train() ema_model.requires_grad_(False) # Handle mixed precision and device placement # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(model).dtype != torch.float32: raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(model).dtype}. 
{low_precision_error_string}" ) if accelerator.is_main_process: total_params = 0 trainable_params = 0 projector_params = 0 for name, param in model.named_parameters(): print(name, param.requires_grad) total_params += param.numel() # Total number of elements in the parameter if param.requires_grad: # Check if the parameter is trainable trainable_params += param.numel() if 'projector' in name: projector_params += param.numel() print(trainable_params, total_params, total_params-projector_params, trainable_params/total_params) # Optimizer creation target_optimizer = train_config.optimizer.get('target', 'torch.optim.AdamW') optimizer = get_obj_from_str(target_optimizer)( model.parameters(), lr=learning_rate, **train_config.optimizer.get("params", dict()) ) # Dataset creation and data processing # Here, we compute not just the text embeddings but also the additional embeddings global_steps = 0 if train_config.resume_from_checkpoint: # normal read with safety check if train_config.resume_from_checkpoint != "latest": resume_from_path = os.path.basename(train_config.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(checkpoint_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None if resume_from_path is None: logger.info( f"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. Starting a new training run." 
) train_config.resume_from_checkpoint = None else: global_steps = int(resume_from_path.split("-")[1]) # gs not calculate the gradient_accumulation_steps logger.info(f"Resuming from steps: {global_steps}") get_train_dataloader = C2ILoader(data_config) train_dataloader = get_train_dataloader.train_dataloader( rank=accelerator.process_index, world_size=accelerator.num_processes, global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_steps=global_steps, seed=args.seed ) # LR Scheduler creation # Scheduler and math around the number of training steps. lr_scheduler = get_scheduler( train_config.lr_scheduler, optimizer=optimizer, num_warmup_steps=train_config.lr_warmup_steps, num_training_steps=train_config.max_train_steps, ) # Prepare for training # Prepare everything with our `accelerator`. if model_config.use_ema: ema_model, model, optimizer, lr_scheduler = accelerator.prepare( ema_model, model, optimizer, lr_scheduler ) else: model, optimizer, lr_scheduler = accelerator.prepare( model, optimizer, lr_scheduler ) # transport loss_fn = FlowMatchingLoss(**OmegaConf.to_container(model_config.transport)) if model_config.enc_type == 'radio': from nit.models.nvidia_radio.hubconf import radio_model encoder = radio_model(version=model_config.enc_dir, progress=True, support_packing=True) encoder.to(device=accelerator.device).eval() encoder.requires_grad_(False) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. 
if accelerator.is_main_process and getattr(train_config, 'tracker', 'wandb') != None: tracker_project_name = project_dir.split('/')[-1] # accelerator.init_trackers("mcga", config=config, init_kwargs=train_config.tracker_kwargs) accelerator.init_trackers(tracker_project_name, config=config, init_kwargs=train_config.tracker_kwargs) # initialize GPU memory monitor before applying parallelisms to the model gpu_memory_monitor = build_gpu_memory_monitor(logger) gpu_mem_stats = gpu_memory_monitor.get_peak_stats() # 15. Train! logger.info("***** Running training *****") logger.info(f" Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}") logger.info(f" Dataset Length = {get_train_dataloader.train_len()}") logger.info(f" Instantaneous batch size per device = {data_config.dataloader.batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {train_config.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {train_config.max_train_steps}") logger.info( " GPU memory usage for model: " f"{gpu_mem_stats.max_reserved_gib:.2f}GiB" f"({gpu_mem_stats.max_reserved_pct:.2f}%)" ) gpu_memory_monitor.reset_peak_stats() data_loading_times = [] feat_enc_times = [] # Potentially load in the weights and states from a previous save if train_config.resume_from_checkpoint and resume_from_path != None: accelerator.print(f"Resuming from checkpoint {resume_from_path}") accelerator.load_state(resume_from_path) progress_bar = tqdm( range(0, train_config.max_train_steps), initial=global_steps, desc="Steps", # Only show the progress bar once on each machine. 
disable=not accelerator.is_main_process, ) for batch in train_dataloader: time_last_log = timer() data_load_start = timer() # load dataset from batch batch_image = [image.to(accelerator.device) for image in batch['image']] batch_label = batch['label'].squeeze(0).to(accelerator.device, torch.int) packed_latent = batch['latent'].squeeze(0).to(accelerator.device) noises = torch.randn_like(packed_latent) hw_list = batch['hw_list'].squeeze(0).to(torch.int) batch_size = hw_list.shape[0] dropout_prob = model_config.network.params.class_dropout_prob num_classes = model_config.network.params.num_classes if dropout_prob > 0: drop_ids = torch.rand(batch_label.shape[0], device=accelerator.device) < dropout_prob batch_label = torch.where(drop_ids, num_classes, batch_label) data_loading_times.append(timer() - data_load_start) feat_enc_start = timer() zs = [] if model_config.enc_type == 'radio': with torch.no_grad(), accelerator.autocast(): raw_images = [(image.unsqueeze(0)+1.0)/2.0 for image in batch_image] _, z = encoder.forward_pack(raw_images) zs.append(z) feat_enc_times.append(timer() - feat_enc_start) with accelerator.accumulate(model): # forward and calculate loss model_kwargs = dict(y=batch_label, hw_list=hw_list) fm_loss, proj_loss = loss_fn(model, batch_size, packed_latent, noises, model_kwargs, use_dir_loss=True, zs=zs) loss = fm_loss + model_config.proj_coeff * proj_loss accelerator.backward(loss) if accelerator.sync_gradients and train_config.max_grad_norm > 0: all_norm = accelerator.clip_grad_norm_(model.parameters(), train_config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: # 20.4.15. 
Make EMA update to target student model parameters if model_config.use_ema: update_ema(ema_model, model, model_config.ema_decay) global_steps += 1 time_delta = timer() - time_last_log sps = batch_size / time_delta time_data_loading = np.mean(data_loading_times) time_feat_enc = np.mean(feat_enc_times) time_data_loading_pct = 100 * time_data_loading / time_delta time_feat_enc_pct = 100 * time_feat_enc / time_delta gpu_mem_stats = gpu_memory_monitor.get_peak_stats() if global_steps % train_config.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if accelerator.is_main_process and train_config.checkpoints_total_limit is not None: checkpoints = os.listdir(checkpoint_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= train_config.checkpoints_total_limit: num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = osp.join(checkpoint_dir, removing_checkpoint) try: shutil.rmtree(removing_checkpoint) except: pass save_path = osp.join(checkpoint_dir, f"checkpoint-{global_steps}") if accelerator.is_main_process: os.makedirs(save_path, exist_ok=True) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") if global_steps in train_config.checkpoint_list: save_path = os.path.join(checkpoint_dir, f"save-checkpoint-{global_steps}") if accelerator.is_main_process: os.makedirs(save_path, exist_ok=True) accelerator.save_state(save_path) logger.info(f"Saved state 
to {save_path}") time.sleep(10) torch.cuda.empty_cache() if global_steps % train_config.validation_steps == 0: log_validation(model) logs = { # loss and lr "loss_denoising": fm_loss.detach().item(), "loss_projector": proj_loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], # time and status "sps": sps, "data_loading(s)": time_data_loading, "data_loading(%)": time_data_loading_pct, "time_feat_enc(s)": time_feat_enc, "time_feat_enc(%)": time_feat_enc_pct, "memory_max_active(GiB)": gpu_mem_stats.max_active_gib, "memory_max_active(%)": gpu_mem_stats.max_active_pct, "memory_max_reserved(GiB)": gpu_mem_stats.max_reserved_gib, "memory_max_reserved(%)": gpu_mem_stats.max_reserved_pct, "memory_num_alloc_retries": gpu_mem_stats.num_alloc_retries, "memory_num_ooms": gpu_mem_stats.num_ooms } if accelerator.sync_gradients and train_config.max_grad_norm > 0: logs.update({'grad_norm': all_norm.item()}) progress_bar.set_postfix(**logs) progress_bar.update(1) accelerator.log(logs, step=global_steps) if global_steps >= train_config.max_train_steps: break # Create the pipeline using using the trained modules and save it. 
accelerator.wait_for_everyone() accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args) ================================================ FILE: requirements.txt ================================================ diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested transformers>=4.44.2 # The development team is working on version 4.44.2 accelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested sentencepiece>=0.2.0 # T5 used numpy==1.26.0 streamlit>=1.38.0 # For streamlit web demo imageio==2.34.2 # For diffusers inference export video imageio-ffmpeg==0.5.1 # For diffusers inference export video moviepy==1.0.3 # For export video pillow==9.5.0 timm safetensors einops triton torchdiffeq ================================================ FILE: scripts/preprocess/preorocess_in1k_256x256.sh ================================================ NNODES=1 GPUS_PER_NODE=8 MASTER_ADDR="localhost" export MASTER_PORT=$((30000 + $RANDOM % 21000)) CMD=" \ projects/preprocess/image_latent_c2i.py \ --config configs/preprocess/imagenet1k_256x256.yaml \ --project_dir workdir/preprocess/imagenet1k_256x256 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/preprocess/preorocess_in1k_512x512.sh ================================================ NNODES=1 GPUS_PER_NODE=8 MASTER_ADDR="localhost" export MASTER_PORT=$((30000 + $RANDOM % 21000)) CMD=" \ projects/preprocess/image_latent_c2i.py \ --config configs/preprocess/imagenet1k_512x512.yaml \ --project_dir workdir/preprocess/imagenet1k_512x512 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint 
$MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/preprocess/preorocess_in1k_native_resolution.sh ================================================ NNODES=1 GPUS_PER_NODE=8 MASTER_ADDR="localhost" export MASTER_PORT=$((30000 + $RANDOM % 21000)) CMD=" \ projects/preprocess/image_nr_latent_c2i.py \ --config configs/preprocess/imagenet1k_native_resolution.yaml \ --project_dir workdir/preprocess/imagenet1k_native_resolution \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/sample/sample_256x256.sh ================================================ torchrun \ --nnodes 1 \ --nproc_per_node 8 \ projects/sample/sample_c2i_ddp.py \ --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \ --ckpt checkpoints/nit_xl_model_1000K.safetensors \ --sample-dir ./samples \ --height 256 \ --width 256 \ --per-proc-batch-size 32 \ --mode sde \ --num-steps 250 \ --cfg-scale 2.25 \ --guidance-low 0.0 \ --guidance-high 0.7 \ --slice_vae \ ================================================ FILE: scripts/sample/sample_512x512.sh ================================================ torchrun \ --nnodes 1 \ --nproc_per_node 8 \ projects/sample/sample_c2i_ddp.py \ --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \ --ckpt checkpoints/nit_xl_model_1000K.safetensors \ --sample-dir ./samples \ --height 512 \ --width 512 \ --per-proc-batch-size 32 \ --mode sde \ --num-steps 250 \ --cfg-scale 2.05 \ --guidance-low 0.0 \ --guidance-high 0.7 \ --slice_vae \ ================================================ FILE: scripts/sample/sample_768x768.sh ================================================ torchrun \ --nnodes 1 \ --nproc_per_node 8 \ projects/sample/sample_c2i_ddp.py \ --config 
configs/c2i/nit_xl_pack_merge_radio_16384.yaml \ --ckpt checkpoints/nit_xl_model_1000K.safetensors \ --sample-dir ./samples \ --height 768 \ --width 768 \ --per-proc-batch-size 32 \ --mode ode \ --num-steps 50 \ --cfg-scale 3.0 \ --guidance-low 0.0 \ --guidance-high 0.7 \ --slice_vae \ ================================================ FILE: scripts/train/train_b_model.sh ================================================ NNODES=1 GPUS_PER_NODE=2 MASTER_ADDR="localhost" export MASTER_PORT=60563 mkdir -p workdir/c2i/nit_b_pack_merge_radio_65536 CMD=" \ projects/train/packed_trainer_c2i.py \ --config configs/c2i/nit_b_pack_merge_radio_65536.yaml \ --project_dir workdir/c2i/nit_b_pack_merge_radio_65536 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/train/train_l_model.sh ================================================ NNODES=1 GPUS_PER_NODE=2 MASTER_ADDR="localhost" export MASTER_PORT=60563 mkdir -p workdir/c2i/nit_l_pack_merge_radio_16384 CMD=" \ projects/train/packed_trainer_c2i.py \ --config configs/c2i/nit_l_pack_merge_radio_16384.yaml \ --project_dir workdir/c2i/nit_l_pack_merge_radio_16384 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/train/train_s_model.sh ================================================ NNODES=1 GPUS_PER_NODE=2 MASTER_ADDR="localhost" export MASTER_PORT=60563 mkdir -p workdir/c2i/nit_s_pack_merge_radio_65536 CMD=" \ projects/train/packed_trainer_c2i.py \ --config configs/c2i/nit_s_pack_merge_radio_65536.yaml \ --project_dir workdir/c2i/nit_s_pack_merge_radio_65536 \ --seed 0 \ " 
TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/train/train_xl_model.sh ================================================ NNODES=1 GPUS_PER_NODE=8 MASTER_ADDR="localhost" export MASTER_PORT=60563 mkdir -p workdir/c2i/nit_xl_pack_merge_radio_16384 CMD=" \ projects/train/packed_trainer_c2i.py \ --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \ --project_dir workdir/c2i/nit_xl_pack_merge_radio_16384 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: scripts/train/train_xxl_model.sh ================================================ NNODES=1 GPUS_PER_NODE=8 MASTER_ADDR="localhost" export MASTER_PORT=60563 mkdir -p workdir/c2i/nit_xxl_pack_merge_radio_8192 CMD=" \ projects/train/packed_trainer_c2i.py \ --config configs/c2i/nit_xxl_pack_merge_radio_8192.yaml \ --project_dir workdir/c2i/nit_xxl_pack_merge_radio_8192 \ --seed 0 \ " TORCHLAUNCHER="torchrun \ --nnodes $NNODES \ --nproc_per_node $GPUS_PER_NODE \ --rdzv_id $RANDOM \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ " bash -c "$TORCHLAUNCHER $CMD" ================================================ FILE: setup.py ================================================ from setuptools import find_packages, setup setup( name='nit', version='0.0.1', description='', packages=find_packages(), install_requires=[ 'torch', 'numpy', ], ) ================================================ FILE: tools/download_dataset_256x256.sh ================================================ target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256" mkdir -p $target_dir 
base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-256x256" files=( "n01440764_n02097298.zip" "n02097474_n02667093.zip" "n02669723_n03530642.zip" "n03532672_n04239074.zip" "n04243546_n15075141.zip" ) for file in "${files[@]}"; do echo "download $file ..." wget -c "$base_url/$file" -O "$target_dir/$file" echo "download $file finished" echo "start unzip $file ..." unzip "$target_dir/$file" -d "$target_dir" echo "unzip $file finished" rm "$target_dir/$file" echo done echo "Successfully download all the sampler-meta" ================================================ FILE: tools/download_dataset_512x512.sh ================================================ target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512" mkdir -p $target_dir base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-512x512" files=( "n01440764_n01697457.zip" "n01698640_n01855672.zip" "n01860187_n02074367.zip" "n02077923_n02097298.zip" "n02097474_n02110063.zip" "n02110185_n02138441.zip" "n02165105_n02415577.zip" "n02417914_n02667093.zip" "n02669723_n02859443.zip" "n02860847_n03041632.zip" "n03042490_n03291819.zip" "n03297495_n03530642.zip" "n03532672_n03743016.zip" "n03759954_n03884397.zip" "n03887697_n04033901.zip" "n04033995_n04239074.zip" "n04243546_n04398044.zip" "n04399382_n04560804.zip" "n04562935_n07745940.zip" "n07747607_n15075141.zip" ) for file in "${files[@]}"; do echo "download $file ..." wget -c "$base_url/$file" -O "$target_dir/$file" echo "download $file finished" echo "start unzip $file ..." 
unzip "$target_dir/$file" -d "$target_dir" echo "unzip $file finished" rm "$target_dir/$file" echo done echo "Successfully download all the sampler-meta" ================================================ FILE: tools/download_dataset_data_meta.sh ================================================ target_dir="datasets/imagenet1k/data_meta" mkdir -p $target_dir base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/data_meta" files=( "dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl" "dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl" "dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl" "dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl" ) for file in "${files[@]}"; do echo "download $file ..." wget -c "$base_url/$file" -O "$target_dir/$file" echo "download $file finished" echo done echo "Successfully download all the data-meta" ================================================ FILE: tools/download_dataset_native_resolution.sh ================================================ target_dir="datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution" mkdir -p $target_dir base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-native-resolution" files=( "n01440764_n01855672.zip" "n01860187_n02097298.zip" "n02097474_n02138441.zip" "n02165105_n02667093.zip" "n02669723_n03041632.zip" "n03042490_n03530642.zip" "n03532672_n03884397.zip" "n03887697_n04239074.zip" "n04243546_n04560804.zip" "n04562935_n15075141.zip" ) for file in "${files[@]}"; do echo "download $file ..." wget -c "$base_url/$file" -O "$target_dir/$file" echo "download $file finished" echo "start unzip $file ..." 
unzip "$target_dir/$file" -d "$target_dir" echo "unzip $file finished" rm "$target_dir/$file" echo done echo "Successfully download all the sampler-meta" ================================================ FILE: tools/download_dataset_sampler_meta.sh ================================================ target_dir="datasets/imagenet1k/sampler_meta" mkdir -p $target_dir base_url="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/sampler_meta" files=( "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json" "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json" "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json" "dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json" ) for file in "${files[@]}"; do echo "download $file ..." wget -c "$base_url/$file" -O "$target_dir/$file" echo "download $file finished" echo done echo "Successfully download all the sampler-meta" ================================================ FILE: tools/pack_dataset.py ================================================ import json from nit.data.pack import pack_dataset import argparse def create_pack(data_meta, max_seq_len, algorithm, split): max_seq_per_pack = max_seq_len with open(data_meta, 'r') as fp: ori_dataset = [json.loads(line) for i, line in enumerate(fp)] dataset_seq_lens = [] dataset_seq_idxs = [] for idx, data in enumerate(ori_dataset): seq_len = int(data['latent_h']*data['latent_w']) # patch_size=1 dataset_seq_lens.append(seq_len) dataset_seq_idxs.append(idx) total_length = len(ori_dataset) run_length = int(total_length / split) all_packed_indices = [] for i in range(split): seq_lens = dataset_seq_lens[i*run_length: (i+1)*run_length] seq_idxs = dataset_seq_idxs[i*run_length: (i+1)*run_length] packed_indices = pack_dataset( algorithm, max_seq_len, max_seq_per_pack, seq_lens, seq_idxs ) all_packed_indices.extend(packed_indices) sampler_json_name = data_meta.split('/')[-1].replace('_meta.jsonl', '') sampler_json_name = 
f"{sampler_json_name}_{algorithm}_{max_seq_len}.json" with open(f'datasets/imagenet1k/sampler_meta/{sampler_json_name}', 'w') as fp: json.dump(all_packed_indices, fp, indent=4) if __name__ == '__main__': parser = argparse.ArgumentParser() # seed parser.add_argument("--data-meta", type=str, default='datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl') parser.add_argument("--max-seq-len", type=int, default=16384) parser.add_argument("--algorithm", type=str, default='LPFHP') parser.add_argument("--split", type=int, default=1) args = parser.parse_args() create_pack( data_meta=args.data_meta, max_seq_len=args.max_seq_len, algorithm=args.algorithm, split=args.split )