Repository: WZDTHU/NiT
Branch: main
Commit: 02fa292a7cd1
Files: 91
Total size: 565.6 KB

Directory structure:
gitextract_jijj82e1/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── c2i/
│   │   ├── nit_b_pack_merge_radio_65536.yaml
│   │   ├── nit_l_pack_merge_radio_16384.yaml
│   │   ├── nit_s_pack_merge_radio_65536.yaml
│   │   ├── nit_xl_pack_merge_radio_16384.yaml
│   │   └── nit_xxl_pack_merge_radio_8192.yaml
│   └── preprocess/
│       ├── imagenet1k_256x256.yaml
│       ├── imagenet1k_512x512.yaml
│       └── imagenet1k_native_resolution.yaml
├── nit/
│   ├── data/
│   │   ├── pack/
│   │   │   ├── __init__.py
│   │   │   ├── ennlshp.py
│   │   │   ├── lpfhp.py
│   │   │   ├── nnlshp.py
│   │   │   └── spfhp.py
│   │   ├── packed_c2i_data.py
│   │   └── sampler_util.py
│   ├── models/
│   │   ├── c2i/
│   │   │   └── nit_model.py
│   │   ├── nvidia_radio/
│   │   │   ├── hubconf.py
│   │   │   └── radio/
│   │   │       ├── __init__.py
│   │   │       ├── adaptor_base.py
│   │   │       ├── adaptor_generic.py
│   │   │       ├── adaptor_mlp.py
│   │   │       ├── adaptor_registry.py
│   │   │       ├── block.py
│   │   │       ├── cls_token.py
│   │   │       ├── common.py
│   │   │       ├── conv.py
│   │   │       ├── dinov2_arch.py
│   │   │       ├── dual_hybrid_vit.py
│   │   │       ├── enable_cpe_support.py
│   │   │       ├── enable_damp.py
│   │   │       ├── enable_spectral_reparam.py
│   │   │       ├── eradio_model.py
│   │   │       ├── extra_models.py
│   │   │       ├── extra_timm_models.py
│   │   │       ├── feature_normalizer.py
│   │   │       ├── forward_intermediates.py
│   │   │       ├── hf_model.py
│   │   │       ├── input_conditioner.py
│   │   │       ├── open_clip_adaptor.py
│   │   │       ├── radio_model.py
│   │   │       ├── vision_transformer_xpos.py
│   │   │       ├── vit_patch_generator.py
│   │   │       └── vitdet.py
│   │   └── utils/
│   │       ├── convs.py
│   │       ├── funcs.py
│   │       ├── norms.py
│   │       └── pos_embeds/
│   │           ├── flash_attn_rotary.py
│   │           ├── rope.py
│   │           └── sincos.py
│   ├── schedulers/
│   │   └── flow_matching/
│   │       ├── loss.py
│   │       └── samplers_c2i.py
│   └── utils/
│       ├── __init__.py
│       ├── deepspeed_zero_to_fp32.py
│       ├── ema.py
│       ├── eval_utils.py
│       ├── freeze.py
│       ├── gpu_memory_monitor.py
│       ├── lr_scheduler.py
│       ├── misc_utils.py
│       ├── model_utils.py
│       ├── train_utils.py
│       ├── util.py
│       ├── video_utils.py
│       └── warp_pos_idx.py
├── projects/
│   ├── evaluate/
│   │   └── adm_evaluator.py
│   ├── preprocess/
│   │   ├── image_latent_c2i.py
│   │   └── image_nr_latent_c2i.py
│   ├── sample/
│   │   └── sample_c2i_ddp.py
│   └── train/
│       └── packed_trainer_c2i.py
├── requirements.txt
├── scripts/
│   ├── preprocess/
│   │   ├── preorocess_in1k_256x256.sh
│   │   ├── preorocess_in1k_512x512.sh
│   │   └── preorocess_in1k_native_resolution.sh
│   ├── sample/
│   │   ├── sample_256x256.sh
│   │   ├── sample_512x512.sh
│   │   └── sample_768x768.sh
│   └── train/
│       ├── train_b_model.sh
│       ├── train_l_model.sh
│       ├── train_s_model.sh
│       ├── train_xl_model.sh
│       └── train_xxl_model.sh
├── setup.py
└── tools/
    ├── download_dataset_256x256.sh
    ├── download_dataset_512x512.sh
    ├── download_dataset_data_meta.sh
    ├── download_dataset_native_resolution.sh
    ├── download_dataset_sampler_meta.sh
    └── pack_dataset.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# UV
#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#uv.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

*.json
*.png 
*.jpg
/checkpoints
/workdir
/datasets
/wandb
/samples

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
<h1 align="center"> Native-Resolution Image Synthesis</h1>

<!-- 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/representation-alignment-for-generation/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=representation-alignment-for-generation) -->



<div align="center">
  <a href="https://github.com/WZDTHU" target="_blank">ZiDong&nbsp;Wang</a><sup>1,2</sup> 
  &ensp; <b>&middot;</b> &ensp;
  <a href="http://leibai.site" target="_blank">Lei&nbsp;Bai</a><sup>2,*</sup> 
  &ensp; <b>&middot;</b> &ensp;
  <a href="https://xyue.io" target="_blank">Xiangyu&nbsp;Yue</a><sup>1</sup> 
  &ensp; <b>&middot;</b> &ensp;
  <a href="https://wlouyang.github.io" target="_blank">Wanli&nbsp;Ouyang</a><sup>1,2</sup>
  &ensp; <b>&middot;</b> &ensp;
  <a href="https://invictus717.github.io" target="_blank">Yiyuan&nbsp;Zhang</a><sup>1,2,*</sup>
  
  <sup>1</sup> MMLab CUHK &emsp; <sup>2</sup>Shanghai AI Lab <br>
  <sup>*</sup>Correspondence &emsp; <br>
</div>
<h3 align="center">
[<a href="https://wzdthu.github.io/NiT">project page</a>]&emsp;
[<a href="https://arxiv.org/abs/2506.03131">arXiv</a>]&emsp;
[<a href="https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K">Dataset</a>]&emsp;
[<a href="https://huggingface.co/GoodEnough/NiT-Models">Model</a>]&emsp;

</h3>
<br>


<b>Summary</b>: We propose the Native-resolution diffusion Transformer (NiT), a model that explicitly learns varying resolutions and aspect ratios within its denoising process. This significantly improves training efficiency and generalization. To the best of our knowledge, <b>NiT is the first model to attain SOTA results on both</b> $256\times256$ ($2.08$ <b>FID</b>) <b>and</b> $512\times512$ ($1.48$ <b>FID</b>) <b>benchmarks in class-conditional ImageNet generation</b>. NiT also generalizes to arbitrary resolutions and aspect ratios, for example $4.52$ FID at $1024\times1024$ and $4.11$ FID at $432\times768$.


![Figure](./assets/teaser.png)

### 🚨 News


- `2025-9-18` NiT is accepted by NeurIPS 2025! 🍺

- `2025-6-3` We are delighted to introduce NiT, which is the first work to explicitly model native resolution image synthesis. We have released the code, pretrained models, and processed dataset of NiT.



### 1. Setup


First, clone the repo:
```bash
git clone https://github.com/WZDTHU/NiT.git && cd NiT
```

#### 1.1 Environment Setup

```bash
conda create -n nit_env python=3.10
conda activate nit_env
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn
pip install -r requirements.txt
pip install -e .
```


#### 1.2 Model Zoo (WIP)

A single NiT-XL model is competitive across multiple benchmarks and achieves SOTA on both the ImageNet-$256\times256$ and $512\times512$ benchmarks.

| Model | Model Zoo | Model Size | FID-256x256 | FID-512x512 | FID-768x768 | FID-1024x1024 |
|---------------|------------|---------|------------|------------|------------|------------|
| NiT-XL-1000K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors) | 675M | 2.16 | 1.57 | 4.05 | 4.52 |
| NiT-XL-1500K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors) | 675M | 2.03 | 1.45 | - | - |


```bash
mkdir checkpoints
wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors" -O checkpoints/nit_xl_model_1000K.safetensors
wget -c "https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors" -O checkpoints/nit_xl_model_1500K.safetensors
```
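
To sanity-check a download, the sketch below reads only the safetensors header and sums the tensor sizes without loading any weights. The header is an 8-byte little-endian length followed by a JSON table of tensor names, dtypes, shapes and byte offsets. The sketch uses only the Python standard library. The helper names `read_safetensors_header` and `count_parameters` are hypothetical. Whether the total matches the Model Size column depends on what the checkpoint actually stores; extra buffers or a second copy of the weights would change it.

```python
import json
import math
import struct
from pathlib import Path


def read_safetensors_header(path):
    """Return the JSON header of a .safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def count_parameters(path):
    """Hypothetical helper: total number of elements over all tensors in the file."""
    header = read_safetensors_header(path)
    return sum(
        math.prod(info["shape"])  # a scalar tensor has shape [] and counts as 1
        for name, info in header.items()
        if name != "__metadata__"  # optional free-form metadata entry, not a tensor
    )


if __name__ == "__main__":
    ckpt = Path("checkpoints/nit_xl_model_1000K.safetensors")
    if ckpt.exists():
        print(f"{count_parameters(ckpt) / 1e6:.1f}M parameters")
```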


### 2. Sampling 

#### 2.1 Sampling Hyper-parameters

The sampling hyper-parameters for NiT-XL-1000K are summarized below:

| Resolution | Solver | NFE | CFG scale | CFG interval | FID | sFID | IS | Prec. | Rec. |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 256 × 256 | SDE | 250 | 2.25 | [0.0, 0.7] | 2.16 | 6.34 | 253.44 | 0.79 | 0.62 |
| 512 × 512 | SDE | 250 |  2.05 | [0.0, 0.7] | 1.57 | 4.13 | 260.69 | 0.81 | 0.63 |
| 768 × 768 | ODE | 50 | 3.0 | [0.0, 0.7] | 4.05 | 8.77 | 262.31 | 0.83 | 0.52 |
| 1024 × 1024 | ODE | 50 |  3.0 | [0.0, 0.8] | 4.52 | 7.99 | 286.87 | 0.82 | 0.50 |
| 1536 × 1536 | ODE | 50 |  3.5 | [0.0, 0.9] | 6.51 | 9.97 | 230.10 | 0.83 | 0.42 |
| 2048 × 2048 | ODE | 50 |  4.5 | [0.0, 0.9] | 24.76 | 18.02 | 131.36 | 0.67 | 0.46 |
| 320 × 960 | ODE | 50 |  4.0 | [0.0, 0.9] | 16.85 | 17.79 | 189.18 | 0.71 | 0.38 |
| 432 × 768 | ODE | 50 |  2.75 | [0.0, 0.7] | 4.11 | 10.30 | 254.71 | 0.83 | 0.55 |
| 480 × 640 | ODE | 50 |  2.75 | [0.0, 0.7] | 3.72 | 8.23 | 284.94 | 0.83 | 0.54 |
| 640 × 480 | ODE | 50 |  2.5 | [0.0, 0.7] | 3.41 | 8.07 | 259.06 | 0.83 | 0.56 |
| 768 × 432 | ODE | 50 |  2.85 | [0.0, 0.7] | 5.27 | 9.92 | 218.78 | 0.80 | 0.55 |
| 960 × 320 | ODE | 50 |  4.5 | [0.0, 0.9] | 9.90 | 25.78 | 255.95 | 0.74 | 0.40 |
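
The ODE rows use 50 solver steps and apply classifier-free guidance only inside the CFG interval. The sketch below illustrates that idea with a plain Euler integrator over Python lists. It is not the repository's sampler (`nit/schedulers/flow_matching/samplers_c2i.py`, not shown here). It rests on three assumptions:

- a REPA/SiT-style linear path $x_t = (1-t)\,x_0 + t\,\epsilon$ with velocity $v = \epsilon - x_0$;
- integration from $t=1$ (noise) to $t=0$ (data);
- a hypothetical `model(x, t, label)` callable, where `label=None` gives the unconditional prediction.

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]
# Hypothetical model interface: velocity prediction for state x at time t.
# label=None stands for the unconditional (class-dropped) branch.
Model = Callable[[Vector, float, Optional[int]], Vector]


def guided_velocity(model: Model, x: Vector, t: float, label: int,
                    cfg_scale: float, cfg_interval: Tuple[float, float]) -> Vector:
    """Classifier-free guidance applied only when t lies inside cfg_interval."""
    v_cond = model(x, t, label)
    low, high = cfg_interval
    if not (low <= t <= high):
        return v_cond  # outside the interval: plain conditional velocity
    v_uncond = model(x, t, None)
    return [u + cfg_scale * (c - u) for c, u in zip(v_cond, v_uncond)]


def euler_ode_sample(model: Model, noise: Vector, label: int, num_steps: int = 50,
                     cfg_scale: float = 3.0,
                     cfg_interval: Tuple[float, float] = (0.0, 0.7)) -> Vector:
    """Integrate dx/dt = v from t=1 (noise) to t=0 (data) with fixed Euler steps."""
    x = list(noise)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt
        v = guided_velocity(model, x, t, label, cfg_scale, cfg_interval)
        # Step from t to t - dt along the assumed linear path: x_{t-dt} = x_t - dt * v
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x
```

With an exact velocity model on this path, $v$ is constant along each trajectory, so the Euler steps recover $x_0$ up to rounding. A learned model is only approximately so.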

#### 2.2 Sampling Scripts

Sampling with NiT-XL-1000K model for $256\times256$-resolution images: 
```bash
bash scripts/sample/sample_256x256.sh
```

Sampling with NiT-XL-1000K model for $512\times512$-resolution images: 
```bash
bash scripts/sample/sample_512x512.sh
```

Sampling with NiT-XL-1000K model for $768\times768$-resolution images: 
```bash
bash scripts/sample/sample_768x768.sh
```

### 3. Evaluation

Sampling writes the generated images to a folder, from which FID, Inception Score, and other metrics are computed.
<b>Note that we do not pack the generated samples into a `.npz` file; this does not affect the calculation of FID or the other metrics.</b>
Please follow the [ADM's TensorFlow
evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations)
to set up the conda environment and download the reference batch.

```bash
wget -c "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" -O checkpoints/classify_image_graph_def.pb
```


Given the directory of the reference batch `REFERENCE_DIR` and the directory of the generated images `SAMPLING_DIR`, run the following command:
```bash
python projects/evaluate/adm_evaluator.py $REFERENCE_DIR $SAMPLING_DIR
```
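
For reference, FID fits a Gaussian to Inception-v3 pool features of the reference batch ($\mu_r, \Sigma_r$) and another to those of the generated samples ($\mu_g, \Sigma_g$), then measures the distance between the two; lower is better. This is the standard definition, not something specific to this repository:

```math
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
```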



### 4. Training

#### 4.1 Dataset Setup

We currently provide the [preprocessed datasets](https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K) for ImageNet1K. Use the following commands to download the meta files and preprocessed latents.

```bash
mkdir datasets
mkdir datasets/imagenet1k

bash tools/download_dataset_256x256.sh
bash tools/download_dataset_512x512.sh
bash tools/download_dataset_native_resolution.sh
```

#### Preprocess ImageNet1K Locally

You can also preprocess the ImageNet1K dataset yourself.
Taking $256\times256$ preprocessing as an example, first set `data_dir` in `configs/preprocess/imagenet1k_256x256.yaml` to your local ImageNet1K directory.
Then run the preprocessing script `scripts/preprocess/preorocess_in1k_256x256.sh`.
```bash
bash scripts/preprocess/preorocess_in1k_256x256.sh
```

The preprocessing procedure for $512\times512$ and native-resolution images is similar.
Modify the corresponding config file and run its script.
```bash
bash scripts/preprocess/preorocess_in1k_512x512.sh
bash scripts/preprocess/preorocess_in1k_native_resolution.sh
```


#### 4.2 Packing 

Because we pack multiple images with distinct resolutions into one sequence, the image indices of each pack must be fixed before training.
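
Inside a packed sequence, tokens from different images generally should not attend to each other. One common way to express this is the cumulative-offset format (`cu_seqlens`) used by flash-attn's variable-length kernels, such as `flash_attn_varlen_func`. This section does not show whether NiT builds its attention mask exactly this way. The stdlib-only sketch below computes those offsets, plus a per-token image index, for one pack holding two images of 216 and 150 latent tokens. The helper `pack_offsets` is hypothetical.

```python
from itertools import accumulate
from typing import List, Tuple


def pack_offsets(seq_lens: List[int]) -> Tuple[List[int], int, List[int]]:
    """Return (cu_seqlens, max_seqlen, image_ids) for one packed sequence.

    Tokens cu_seqlens[i]:cu_seqlens[i + 1] belong to image i, which is the layout
    flash-attn's varlen kernels expect (as int32 tensors in practice).
    """
    cu_seqlens = [0] + list(accumulate(seq_lens))
    max_seqlen = max(seq_lens) if seq_lens else 0
    # image_ids[j] says which image token j belongs to; equal ids may attend to each other.
    image_ids = [i for i, n in enumerate(seq_lens) for _ in range(n)]
    return cu_seqlens, max_seqlen, image_ids


cu, max_len, ids = pack_offsets([216, 150])
print(cu, max_len, len(ids))  # [0, 216, 366] 216 366
```

The dense equivalent is a block-diagonal mask built from `image_ids`: token `j` may attend to token `k` only when `image_ids[j] == image_ids[k]`. That mask costs memory quadratic in the pack length.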

#### Download meta-info files
First, download all the data-meta files, which store the height, width, and other information of each image.
```bash
bash tools/download_dataset_data_meta.sh
```
The above command downloads the four data-meta files to the `datasets/imagenet1k/data_meta` directory:

- `dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl`: data-meta file for $256\times256$-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl`: data-meta file for $512\times512$-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`: data-meta file for native-resolution image data.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl`: a merged file of the above three files.

The first two items of the native-resolution-image data-meta file (`dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`) are as follows:
```json
{"image_file": "n01601694/n01601694_11629.JPEG", "latent_file": "n01601694/n01601694_11629.safetensors", "ori_w": 580, "ori_h": 403, "latent_h": 12, "latent_w": 18, "image_h": 384, "image_w": 576, "type": "native-resolution"}
{"image_file": "n01601694/n01601694_11799.JPEG", "latent_file": "n01601694/n01601694_11799.safetensors", "ori_w": 500, "ori_h": 350, "latent_h": 10, "latent_w": 15, "image_h": 320, "image_w": 480, "type": "native-resolution"}
```
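
Each line is a standalone JSON object, so the file can be streamed line by line. The sketch below reads entries like the two above and derives the number of latent tokens per image. It uses only the standard library, and `summarize` is a hypothetical helper. It assumes `patch_size: 1`, as in the `configs/c2i` files, so there is one token per latent pixel. It also checks the 32× spatial downsampling of the `dc-ae-f32c32` autoencoder (`image_h == 32 * latent_h`), which both example entries satisfy. In both examples, `image_w`/`image_h` are the original sizes rounded down to multiples of 32 (580 × 403 → 576 × 384). Whether every entry follows that rule is an assumption, not something stated here.

```python
import json
from typing import Iterable, List

DOWNSAMPLE = 32  # spatial factor of the dc-ae-f32c32 autoencoder ("f32")
PATCH_SIZE = 1   # patch_size used in the configs/c2i files


def summarize(lines: Iterable[str]) -> List[dict]:
    """Parse data-meta JSONL lines and attach the latent token count of each image."""
    rows = []
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines between records
            continue
        meta = json.loads(line)
        for side in ("h", "w"):
            if meta[f"image_{side}"] != DOWNSAMPLE * meta[f"latent_{side}"]:
                raise ValueError(f"{meta['image_file']}: image_{side} is not {DOWNSAMPLE}x latent_{side}")
        tokens = (meta["latent_h"] // PATCH_SIZE) * (meta["latent_w"] // PATCH_SIZE)
        rows.append({"latent_file": meta["latent_file"], "type": meta["type"], "tokens": tokens})
    return rows


example = [
    '{"image_file": "n01601694/n01601694_11629.JPEG", "latent_file": "n01601694/n01601694_11629.safetensors", "ori_w": 580, "ori_h": 403, "latent_h": 12, "latent_w": 18, "image_h": 384, "image_w": 576, "type": "native-resolution"}',
    '{"image_file": "n01601694/n01601694_11799.JPEG", "latent_file": "n01601694/n01601694_11799.safetensors", "ori_w": 500, "ori_h": 350, "latent_h": 10, "latent_w": 15, "image_h": 320, "image_w": 480, "type": "native-resolution"}',
]
for row in summarize(example):
    print(row["latent_file"], row["tokens"])  # 216, then 150 tokens
```

On the real file, pass `open("datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl")` instead of `example`.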

#### Sampler-Meta Download

Given the maximum sequence length $L$, we fix the image indices of each pack before training.
We use the LPFHP (longest-pack-first histogram packing) algorithm to pack the whole dataset.

You can download our preprocessed packed sampler-meta file using the following command.
```bash
bash tools/download_dataset_sampler_meta.sh
```
The above command downloads four sampler-meta files to the `datasets/imagenet1k/sampler_meta` directory:
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json`: corresponds to $L=8192$.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json`: corresponds to $L=16384$. This is the setting in NiT-XL experiments.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json`: corresponds to $L=32768$.
- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json`: corresponds to $L=65536$.
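
To get a feel for these budgets: a fixed-resolution image contributes $(H/32)\cdot(W/32)$ tokens, assuming the f32 autoencoder and the `patch_size: 1` used in the configs. The sketch below computes how many such images would fit into one pack of length $L$ if the pack held only that resolution. The helper names are hypothetical. Real packs mix resolutions, so these numbers are upper bounds, not measured pack statistics.

```python
def tokens_per_image(height: int, width: int, downsample: int = 32, patch_size: int = 1) -> int:
    """Latent tokens for one image: (H / f / p) * (W / f / p)."""
    return (height // downsample // patch_size) * (width // downsample // patch_size)


def max_images_per_pack(max_seq_len: int, height: int, width: int) -> int:
    """Upper bound on images of a single resolution that fit in one pack."""
    return max_seq_len // tokens_per_image(height, width)


for L in (8192, 16384, 32768, 65536):
    # e.g. L=16384 -> 256 images at 256x256 (64 tokens each), 64 at 512x512 (256 tokens each)
    print(L, max_images_per_pack(L, 256, 256), max_images_per_pack(L, 512, 512))
```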


#### Prepare the Packing (Sampler-Meta) on Your Own

NiT supports training on images of arbitrary resolutions and aspect ratios, so you can also prepare the packing (sampler-meta) to suit your own needs.

```bash
# generate the default sampler-meta
python tools/pack_dataset.py
# generate the sampler-meta for a fixed 256x256-resolution experiment with a maximum sequence length of 16384
python tools/pack_dataset.py --data-meta datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl --max-seq-len 16384
```
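
For intuition about the kind of output `tools/pack_dataset.py` produces, the sketch below groups per-image token counts into packs whose totals stay within `--max-seq-len`. It deliberately swaps in a simpler heuristic, first-fit decreasing, instead of the LPFHP algorithm the repository uses. The function name `first_fit_decreasing` and its output format (lists of image indices per pack) are assumptions, not the tool's actual sampler-meta schema.

```python
from typing import List


def first_fit_decreasing(seq_lens: List[int], max_seq_len: int) -> List[List[int]]:
    """Group item indices into packs whose summed lengths do not exceed max_seq_len."""
    # Longest items first; sorted() is stable, so equal lengths keep their input order.
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    packs: List[List[int]] = []
    free: List[int] = []  # remaining capacity of each pack
    for i in order:
        n = seq_lens[i]
        if n > max_seq_len:
            raise ValueError(f"item {i} has {n} tokens, more than max_seq_len={max_seq_len}")
        for p, cap in enumerate(free):
            if n <= cap:  # first existing pack with enough room
                packs[p].append(i)
                free[p] -= n
                break
        else:  # no existing pack fits: open a new one
            packs.append([i])
            free.append(max_seq_len - n)
    return packs


# Token counts: 256x256 (64), 512x512 (256) and two native-resolution images (216, 150).
print(first_fit_decreasing([64, 256, 216, 150, 64], max_seq_len=512))  # [[1, 2], [3, 0, 4]]
```

Histogram-based packers such as LPFHP work on counts of each sequence length rather than visiting items one by one. That makes them better suited to a dataset the size of ImageNet than this per-item loop.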



#### Download Image Encoder

For the NiT-S (33M) model, we use RADIO-v2.5-B as the image encoder for the REPA loss.
For the other NiT models, we use RADIO-v2.5-H.

```bash
wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar" -O checkpoints/radio_v2.5-h.pth.tar

wget -c "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar" -O checkpoints/radio-v2.5-b_half.pth.tar
```
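
The encoder only supplies target features for the REPA term, which the configs weight with `proj_coeff`. In the original REPA formulation, projected intermediate transformer tokens are aligned with the frozen encoder's patch features via negative cosine similarity averaged over tokens. The stdlib sketch below computes that quantity for plain lists. It is an assumption that NiT's implementation (`nit/schedulers/flow_matching/loss.py`, not shown here) matches it exactly, including how tokens are matched across resolutions.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity with a small floor on the norm product to avoid division by zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, eps)


def repa_loss(projected: List[Vector], targets: List[Vector], proj_coeff: float = 1.0) -> float:
    """Negative mean cosine similarity between per-token projections and encoder features."""
    if len(projected) != len(targets) or not projected:
        raise ValueError("need the same, non-zero number of projected and target tokens")
    sims = [cosine(h, y) for h, y in zip(projected, targets)]
    return -proj_coeff * sum(sims) / len(sims)
```

The loss is $-1$ when every projected token points in the same direction as its target feature and $0$ when they are orthogonal.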


#### Training Scripts
The above steps set up the `packed_json`, `jsonl_dir`, and `latent_dirs` entries in `configs/c2i/nit_xl_pack_merge_radio_16384.yaml`.
Before training, set `image_dir` to the ImageNet1K directory on your machine.
To train the XL model (675M):
```bash
bash scripts/train/train_xl_model.sh
```

Specify the `image_dir` in `configs/c2i/nit_s_pack_merge_radio_65536.yaml` and train the small model (33M):
```bash
bash scripts/train/train_s_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_b_pack_merge_radio_65536.yaml` and train the base model (131M):
```bash
bash scripts/train/train_b_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_l_pack_merge_radio_16384.yaml` and train the large model (457M):
```bash
bash scripts/train/train_l_model.sh
```
Specify the `image_dir` in `configs/c2i/nit_xxl_pack_merge_radio_8192.yaml` and train the XXL model (1.37B):
```bash
bash scripts/train/train_xxl_model.sh
```
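
The configs pair `learning_rate: 5.0e-5` with `learning_rate_base_batch_size: 1` and `scale_lr: True` (see `configs/c2i/nit_b_pack_merge_radio_65536.yaml`). A common convention for such settings scales the learning rate linearly with the global batch size (per-device batch × processes × gradient-accumulation steps), relative to the base batch size. The sketch below follows that convention. It is an assumption about `projects/train/packed_trainer_c2i.py`, which is not shown here. With `batch_size: 1`, one "batch" per device is presumably one pack rather than one image.

```python
def scaled_learning_rate(base_lr: float, per_device_batch: int, num_processes: int,
                         grad_accum_steps: int = 1, base_batch_size: int = 1) -> float:
    """Linear LR scaling: base_lr * global_batch / base_batch_size (assumed convention)."""
    global_batch = per_device_batch * num_processes * grad_accum_steps
    return base_lr * global_batch / base_batch_size


# e.g. 8 GPUs, one pack per GPU, no gradient accumulation
print(scaled_learning_rate(5.0e-5, per_device_batch=1, num_processes=8))
```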




### Citations
If you find the project useful, please kindly cite: 
```bibtex
@article{wang2025native,
  title={Native-Resolution Image Synthesis}, 
  author={Wang, Zidong and Bai, Lei and Yue, Xiangyu and Ouyang, Wanli and Zhang, Yiyuan},
  year={2025},
  eprint={2506.03131},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

### License
This project is licensed under the Apache-2.0 license.


================================================
FILE: configs/c2i/nit_b_pack_merge_radio_65536.yaml
================================================
model: 
  transport:
    path_type: linear
    prediction: v
    weighting: lognormal
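    # Assumed reading of the three transport keys, following the REPA/SiT-style
    # convention; this config does not spell it out:
    #   path_type: linear    -> x_t = (1 - t) * x_0 + t * noise, with t in [0, 1]
    #   prediction: v        -> the network regresses v = noise - x_0
    #   weighting: lognormal -> t = sigmoid(n) with n ~ N(0, 1) (logit-normal timesteps)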
  network:
    target: nit.models.c2i.nit_model.NiT
    params:
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 12
      hidden_size: 768
      patch_size: 1
      in_channels: 32
      num_heads: 12
      qk_norm: True
      encoder_depth: 4
      z_dim: 1280
      use_checkpoint: False
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  slice_vae: False
  tile_vae: False
  # repa encoder
  enc_type: radio
  enc_dir: checkpoints/radio_v2.5-h.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
  data_type: improved_pack
  dataset:
    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json
    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
    latent_dirs: [
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
    ]
    image_dir: <Your imagenet1k directory>/train
  dataloader:
    num_workers: 4
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 'c2i'}}
  max_train_steps: 2000000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 5.0e-5
  learning_rate_base_batch_size: 1
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [200000, 500000, 100000, 150000]


================================================
FILE: configs/c2i/nit_l_pack_merge_radio_16384.yaml
================================================
model: 
  transport:
    path_type: linear
    prediction: v
    weighting: lognormal
  network:
    target: nit.models.c2i.nit_model.NiT
    params:
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 24
      hidden_size: 1024
      patch_size: 1
      in_channels: 32
      num_heads: 16
      qk_norm: True
      encoder_depth: 6
      z_dim: 1280
      use_checkpoint: False
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  slice_vae: False
  tile_vae: False
  # repa encoder
  enc_type: radio
  enc_dir: checkpoints/radio_v2.5-h.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
  data_type: improved_pack
  dataset:
    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json
    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
    latent_dirs: [
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
    ]
    image_dir: <Your imagenet1k directory>/train
  dataloader:
    num_workers: 4
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 'c2i'}}
  max_train_steps: 2000000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 5.0e-5
  learning_rate_base_batch_size: 4
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [200000, 500000, 100000, 150000]


================================================
FILE: configs/c2i/nit_s_pack_merge_radio_65536.yaml
================================================
model: 
  transport:
    path_type: linear
    prediction: v
    weighting: lognormal
  network:
    target: nit.models.c2i.nit_model.NiT
    params:
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 12
      hidden_size: 384
      patch_size: 1
      in_channels: 32
      num_heads: 6
      qk_norm: True
      encoder_depth: 4
      z_dim: 768
      projector_dim: 768
      use_checkpoint: False
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  slice_vae: False
  tile_vae: False
  # repa encoder
  enc_type: radio
  enc_dir: checkpoints/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
  data_type: improved_pack
  dataset:
    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json
    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
    latent_dirs: [
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
    ]
    image_dir: <Your imagenet1k directory>/train
  dataloader:
    num_workers: 4
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 'c2i'}}
  max_train_steps: 2000000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 5.0e-5
  learning_rate_base_batch_size: 1
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [200000, 500000, 100000, 150000]


================================================
FILE: configs/c2i/nit_xl_pack_merge_radio_16384.yaml
================================================
model: 
  transport:
    path_type: linear
    prediction: v
    weighting: lognormal
  network:
    target: nit.models.c2i.nit_model.NiT
    params:
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 28
      hidden_size: 1152
      patch_size: 1
      in_channels: 32
      num_heads: 16
      qk_norm: True
      encoder_depth: 8
      z_dim: 1280
      use_checkpoint: False
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  slice_vae: False
  tile_vae: False
  # repa encoder
  enc_type: radio
  enc_dir: checkpoints/radio_v2.5-h.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
  data_type: improved_pack
  dataset:
    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json
    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
    latent_dirs: [
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
    ]
    image_dir: <Your imagenet1k directory>/train
  dataloader:
    num_workers: 4
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 'c2i'}}
  max_train_steps: 2000000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 5.0e-5
  learning_rate_base_batch_size: 4
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [200000, 500000, 100000, 150000]


================================================
FILE: configs/c2i/nit_xxl_pack_merge_radio_8192.yaml
================================================
model: 
  transport:
    path_type: linear
    prediction: v
    weighting: lognormal
  network:
    target: nit.models.c2i.nit_model.NiT
    params:
      class_dropout_prob: 0.1
      num_classes: 1000
      depth: 40
      hidden_size: 1536
      patch_size: 1
      in_channels: 32
      num_heads: 24
      qk_norm: True
      encoder_depth: 8
      z_dim: 1280
      use_checkpoint: False
      use_adaln_lora: True
      adaln_lora_dim: 512
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  slice_vae: False
  tile_vae: False
  # repa encoder
  enc_type: radio
  enc_dir: checkpoints/radio_v2.5-h.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
  
data:
  data_type: improved_pack
  dataset:
    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json
    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl
    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']
    latent_dirs: [
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',
      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',
    ]
    image_dir: <Your imagenet1k directory>/train
  dataloader:
    num_workers: 4
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 'c2i'}}
  max_train_steps: 1000000
  checkpointing_steps: 2000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 5.0e-5
  learning_rate_base_batch_size: 4
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [200000, 500000, 100000]


================================================
FILE: configs/preprocess/imagenet1k_256x256.yaml
================================================
model:
  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers

data:
  dataset:
    data_dir: <Your imagenet1k directory>/train
    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256
    resolution: 256
  dataloader:
    num_workers: 1
    batch_size: 64  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 't2i'}}
  max_train_steps: 100000
  checkpointing_steps: 200
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 256
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 4000
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [20000, 40000, 60000, 80000]


================================================
FILE: configs/preprocess/imagenet1k_512x512.yaml
================================================
model:
  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers

data:
  dataset:
    data_dir: <Your imagenet1k directory>/train
    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512
    resolution: 512
  dataloader:
    num_workers: 1
    batch_size: 16  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 't2i'}}
  max_train_steps: 100000
  checkpointing_steps: 200
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 256
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 4000
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [20000, 40000, 60000, 80000]


================================================
FILE: configs/preprocess/imagenet1k_native_resolution.yaml
================================================
model:
  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers

data:
  dataset:
    data_dir: <Your imagenet1k directory>/train
    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution
    min_image_size: 32
    max_image_size: 2048
  dataloader:
    num_workers: 1
    batch_size: 1  # Batch size (per device) for the training dataloader.

  
  
training:
  tracker: null
  tracker_kwargs: {'wandb': {'group': 't2i'}}
  max_train_steps: 100000
  checkpointing_steps: 200
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 256
  scale_lr: True
  lr_scheduler: constant # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 4000
  gradient_accumulation_steps: 1
  optimizer: 
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
  proportion_empty_prompts: 0.0
  mixed_precision: bf16 # ["no", "fp16", "bf16"]
  allow_tf32: True 
  validation_steps: 500
  checkpoint_list: [20000, 40000, 60000, 80000]


================================================
FILE: nit/data/pack/__init__.py
================================================
from .ennlshp import ENNLSHP
from .lpfhp import LPFHP
from .nnlshp import NNLSHP
from .spfhp import SPFHP

import json
import torch
import numpy as np
from tqdm import tqdm

def get_strategy(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens):
    def generate_histogram(dataset_seq_lens):
        histogram = np.zeros(max_seq_len, dtype=np.int64)
        seq_lens, counts = np.unique(np.array(dataset_seq_lens), return_counts=True)
        histogram[seq_lens - 1] = counts
        return histogram
    histogram = generate_histogram(dataset_seq_lens)
    if algorithm == "SPFHP":
        strategy = SPFHP(histogram, max_seq_len, max_seq_per_pack)
    elif algorithm == "LPFHP":
        strategy = LPFHP(histogram, max_seq_len, max_seq_per_pack)
    elif algorithm == 'ENNLSHP':
        strategy = ENNLSHP(histogram, max_seq_len, max_seq_per_pack)
    elif algorithm == 'NNLSHP':
        strategy = NNLSHP(histogram, max_seq_len, max_seq_per_pack)
    else:
        raise NotImplementedError(
            f"Unsupported packing algorithm {algorithm!r}. Pass one of: SPFHP, LPFHP, ENNLSHP, NNLSHP"
        )
    return strategy

def pack_dataset(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens, dataset_seq_idxs):
    dataset_seqs = torch.stack([torch.tensor(dataset_seq_lens), torch.tensor(dataset_seq_idxs)])
    strategy_set, strategy_repeat_count = get_strategy(
        algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens
    )
    
    packed_indices = []
    run_iters = sum(strategy_repeat_count)
    progress_bar = tqdm(range(run_iters))
    for i in range(len(strategy_repeat_count)):
        strategy = strategy_set[i]
        for _ in range(strategy_repeat_count[i]):
            progress_bar.update(1)
            ref_inds = []
            for x in strategy:
                ref_ind = torch.argwhere(dataset_seqs[0] == x)[-1]
                dataset_seqs[0, ref_ind] = -1
                ref_inds.append(ref_ind)
            inds = dataset_seqs[1, ref_inds].ravel()
            packed_indices.append(inds.tolist())
    return packed_indices
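`pack_dataset` resolves each strategy into concrete sample indices. It calls `torch.argwhere` over the whole length tensor once per packed sequence, which makes the loop quadratic in dataset size. The sketch below does the same bookkeeping with one stack of indices per sequence length. It assumes the inputs are plain Python sequences. The helper name `assign_indices` is hypothetical and is not part of the repo.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def assign_indices(
    seq_lens: Sequence[int],
    seq_idxs: Sequence[int],
    strategy_set: Sequence[Sequence[int]],
    strategy_repeat_count: Sequence[int],
) -> List[List[int]]:
    """Map (strategy, repeat count) pairs to lists of sample indices, one list per pack."""
    # One stack of sample indices per sequence length, in dataset order.
    stacks: Dict[int, List[int]] = defaultdict(list)
    for length, idx in zip(seq_lens, seq_idxs):
        stacks[int(length)].append(int(idx))
    packed: List[List[int]] = []
    for strategy, repeat in zip(strategy_set, strategy_repeat_count):
        for _ in range(int(repeat)):
            # pop() takes the last unused sample of that length, which is the
            # same choice torch.argwhere(...)[-1] makes in pack_dataset.
            # An exhausted length raises IndexError, as the original loop does.
            packed.append([stacks[int(length)].pop() for length in strategy])
    return packed
```

Because `list.pop()` returns the most recently appended index, each pack should receive the same sample indices as the `argwhere`-based loop. Each lookup, however, is O(1).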



================================================
FILE: nit/data/pack/ennlshp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/ennlshp.py
"""Extended Non-Negative least squares histogram-packing."""
import time
import numpy as np
from scipy import optimize, stats
from functools import lru_cache


def get_packing_matrix(strategy_set, max_sequence_length):
    num_strategies = len(strategy_set)
    A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32)
    for i, strategy in enumerate(strategy_set):
        for seq_len in strategy:
            A[seq_len - 1, i] += 1
    return A


@lru_cache(maxsize=None)
def get_packing_strategies(start_length, minimum_increment, target_length, depth):
    gap = target_length - start_length
    strategies = []
    # Complete the packing with exactly 1 number
    if depth == 1:
        if gap >= minimum_increment:
            strategies.append([gap])
    # Complete the sample in "depth" steps, recursively
    else:
        for new in range(minimum_increment, gap + 1):
            new_gap = target_length - start_length - new
            if new_gap == 0:
                strategies.append([new])
            else:
                options = get_packing_strategies(start_length + new, new, target_length, depth - 1)
                for option in options:
                    if len(option) > 0:
                        strategies.append([new] + option)
    return strategies


def ENNLSHP(histogram, max_sequence_length, max_sequences_per_pack):
    # List all unique ways of packing to the desired maximum sequence length
    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)
    # Get the packing matrix corresponding to this list of packing strategies
    A = get_packing_matrix(strategy_set, max_sequence_length)
    # Weights that penalize the residual by the number of resulting padding tokens.
    w0 = np.array([x + 1 for x in range(max_sequence_length)])
    # construct the packing matrix
    A_bar = np.zeros((2 * max_sequence_length, len(strategy_set) + max_sequence_length), "d")
    # Base weighted matrix
    A_bar[:max_sequence_length, : len(strategy_set)] = np.expand_dims(w0, -1) * A
    # Higher weight to avoid positive residual
    A_bar[max_sequence_length:, : len(strategy_set)] = np.expand_dims(10**6 * np.ones([max_sequence_length]), -1) * A
    # negative diagonal unity matrix for mapping to residual
    A_bar[max_sequence_length:, len(strategy_set) :] = -(10**6) * np.eye(max_sequence_length)
    b_bar = np.zeros(2 * max_sequence_length)
    # Apply weighting to histogram vector
    b_bar[:max_sequence_length] = w0 * histogram
    b_bar[max_sequence_length:] = 10**6 * np.ones([max_sequence_length]) * histogram
    # Solve the packing problem
    start = time.time()
    strategy_residual, rnorm = optimize.nnls(A_bar, b_bar)
    strategy_repeat_count = strategy_residual[: len(strategy_set)]
    # Round the floating point solution to nearest integer
    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)
    # Compute the residuals, shape: [max_sequence_length]
    residual = histogram - A @ strategy_repeat_count
    # Handle the left-over sequences; that is the positive part of residual
    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]
    for l in unpacked_seqlen:
        strategy = sorted([l, max_sequence_length - l])  # the depth 1 strategy
        strategy_index = strategy_set.index(strategy)
        strategy_repeat_count[strategy_index] += residual[l - 1]
    # Re-compute the residual with the updated strategy_repeat_count
    # This should now be <= 0 for every sequence length
    residual = histogram - A @ strategy_repeat_count
    # Add padding based on deficit (negative residual portion of residual)
    padding = np.where(residual < 0, -residual, 0)

    # Calculate some basic statistics
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = int(strategy_repeat_count.sum())
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    num_padding_tokens_packed = (sequence_lengths * padding).sum()
    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )
    return strategy_set, strategy_repeat_count
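All three packers print the same "Speed-up theoretical limit". Expanding the `speedup_upper_bound` expression shows what it measures. The notation is h_l = `histogram[l - 1]`, N = `histogram.sum()` (the number of unpacked samples), and L = `max_sequence_length`:

$$
\begin{aligned}
\text{speedup\_upper\_bound}
&= \frac{1}{1 - \frac{1}{N}\sum_{l=1}^{L} h_l\left(1 - \frac{l}{L}\right)}
 = \frac{1}{1 - \frac{1}{N}\left(N - \frac{1}{L}\sum_{l=1}^{L} l\,h_l\right)} \\
&= \frac{1}{\frac{1}{N L}\sum_{l=1}^{L} l\,h_l}
 = \frac{N L}{\sum_{l=1}^{L} l\,h_l}
\end{aligned}
$$

The bound is therefore the token budget N·L of the padded, unpacked dataset divided by the number of real tokens. The "Achieved speed-up" printed next to it is N divided by the number of packs. Each pack holds at most L real tokens, so the achieved value can reach the bound only when every pack is completely full.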


================================================
FILE: nit/data/pack/lpfhp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/lpfhp.py
"""Longest-pack-first histogram-packing."""
from collections import defaultdict
import numpy as np
import time


def add_pack(pack, count, tmp, final, limit, offset, max_sequence_length=512):
    """Filter out packs that reached maximum length or number of components."""
    # sanity checks
    assert max_sequence_length - sum(pack) == offset, "Incorrect offset."
    assert offset >= 0, "Too small offset."
    assert offset < max_sequence_length, "Too large offset."
    if len(pack) == limit or offset == 0:
        final[offset].append((count, pack))
    else:
        tmp[offset].append((count, pack))


def LPFHP(histogram, max_sequence_length, max_sequences_per_pack, distribute=True):
    """Longest-pack-first histogram-packing."""
    start = time.time()
    reversed_histogram = np.flip(histogram)
    # Initialize main strategy data dictionary.
    # The key indicates how many tokens are left for full length.
    # The value is a list of tuples, consisting of counts and respective packs.
    # A pack is a (sorted) list of sequence length values that get concatenated.
    tmp_strategies_per_length = defaultdict(list)
    strategies_per_length = defaultdict(list)
    if max_sequences_per_pack == "max":
        max_sequences_per_pack = max_sequence_length
    # The histogram is reversed, so index i holds sequences of length max_sequence_length - i (i tokens short of a full pack)
    for i in range(max_sequence_length):
        n_sequences_to_bin = reversed_histogram[i]
        length_to_bin = max_sequence_length - i
        offset = 0  # smallest possible offset for perfect fit
        while n_sequences_to_bin > 0:
            if (length_to_bin + offset) in tmp_strategies_per_length:
                # extract worst pack that will get modified
                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()
                # calculate how often the current sequence maximally fits in
                repeat = min(1 + offset // length_to_bin, max_sequences_per_pack - len(pack))
                # correct dependent on count
                while n_sequences_to_bin // repeat == 0:
                    repeat -= 1
                if not distribute:
                    repeat = 1
                new_pack = pack + [length_to_bin] * repeat
                count = min(n_sequences_to_pack, n_sequences_to_bin // repeat)
                if n_sequences_to_pack > count:
                    # old pack gets reduced
                    n_sequences_to_pack -= count
                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))
                    n_sequences_to_bin -= count * repeat
                else:
                    n_sequences_to_bin -= n_sequences_to_pack * repeat
                add_pack(
                    new_pack,
                    count,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    offset - (repeat - 1) * length_to_bin,
                    max_sequence_length,
                )
                # clean up to speed up main key search
                if not tmp_strategies_per_length[length_to_bin + offset]:
                    tmp_strategies_per_length.pop(length_to_bin + offset)
                # reset offset in case best fit changed
                offset = 0
            else:
                offset += 1
            # Does not fit anywhere. Create new pack.
            if offset >= max_sequence_length - length_to_bin + 1:
                # same repeat logic as above, but the new pack starts empty
                repeat = min(max_sequence_length // length_to_bin, max_sequences_per_pack)
                while n_sequences_to_bin // repeat == 0:
                    repeat -= 1
                if not distribute:
                    repeat = 1
                add_pack(
                    [length_to_bin] * repeat,
                    n_sequences_to_bin // repeat,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    max_sequence_length - length_to_bin * repeat,
                    max_sequence_length,
                )
                n_sequences_to_bin -= n_sequences_to_bin // repeat * repeat
    # merge all strategies
    for key in tmp_strategies_per_length:
        strategies_per_length[key].extend(tmp_strategies_per_length[key])
    # flatten strategies dictionary
    strategy_set = []
    strategy_repeat_count = []
    for key in strategies_per_length:
        for count, pack in strategies_per_length[key]:
            pack.reverse()
            strategy_set.append(pack)
            strategy_repeat_count.append(count)

    # Summarize efficiency of solution
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    strategy_repeat_count = np.array(strategy_repeat_count)
    n_strategies = len(strategy_set)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = strategy_repeat_count.sum()
    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])
    total_tokens = max_sequence_length * new_number_of_samples
    empty_tokens = sum(
        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]
    )
    efficiency = 100 - empty_tokens / total_tokens * 100
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )

    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n",
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )

    return strategy_set, strategy_repeat_count


================================================
FILE: nit/data/pack/nnlshp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/nnlshp.py
"""Non-Negative least squares histogram-packing."""
import time
import numpy as np
from scipy import optimize, stats
from functools import lru_cache


def get_packing_matrix(strategy_set, max_sequence_length):
    num_strategies = len(strategy_set)
    A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32)
    for i, strategy in enumerate(strategy_set):
        for seq_len in strategy:
            A[seq_len - 1, i] += 1
    return A


@lru_cache(maxsize=None)
def get_packing_strategies(start_length, minimum_increment, target_length, depth):
    gap = target_length - start_length
    strategies = []
    # Complete the packing with exactly 1 number
    if depth == 1:
        if gap >= minimum_increment:
            strategies.append([gap])
    # Complete the sample in "depth" steps, recursively
    else:
        for new in range(minimum_increment, gap + 1):
            new_gap = target_length - start_length - new
            if new_gap == 0:
                strategies.append([new])
            else:
                options = get_packing_strategies(start_length + new, new, target_length, depth - 1)
                for option in options:
                    if len(option) > 0:
                        strategies.append([new] + option)
    return strategies


def NNLSHP(histogram, max_sequence_length, max_sequences_per_pack):
    # List all unique ways of packing to the desired maximum sequence length
    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)
    # Get the packing matrix corresponding to this list of packing strategies
    A = get_packing_matrix(strategy_set, max_sequence_length)
    # Weights that penalize the residual on short sequences less.
    penalization_cutoff = 8
    w0 = np.ones([max_sequence_length])
    w0[:penalization_cutoff] = 0.09
    # Solve the packing problem
    start = time.time()
    strategy_repeat_count, rnorm = optimize.nnls(np.expand_dims(w0, -1) * A, w0 * histogram)
    # Round the floating point solution to nearest integer
    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)
    # Compute the residuals, shape: [max_sequence_length]
    residual = histogram - A @ strategy_repeat_count
    # Handle the left-over sequences, that is the positive part of residual
    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]
    for l in unpacked_seqlen:
        strategy = sorted([l, max_sequence_length - l])  # the depth 1 strategy
        strategy_index = strategy_set.index(strategy)
        strategy_repeat_count[strategy_index] += residual[l - 1]
    # Re-compute the residual with the updated strategy_repeat_count
    # This should now be <= 0 for every sequence length
    residual = histogram - A @ strategy_repeat_count
    # Add padding based on deficit (negative residual portion of residual)
    padding = np.where(residual < 0, -residual, 0)

    # Calculate some basic statistics
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = int(strategy_repeat_count.sum())
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )
    num_padding_tokens_packed = (sequence_lengths * padding).sum()
    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)
    print(
        f"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n"
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )

    return strategy_set, strategy_repeat_count
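

# Worked example (illustrative sketch, not part of the original Graphcore code):
# the "Speed-up theoretical limit" printed by NNLSHP simplifies algebraically:
#   1 / (1 - sum_l h[l] * (1 - l / L) / n) = n * L / sum_l (h[l] * l),
# where h is the histogram, L = max_sequence_length and n = h.sum(). In words,
# it is the number of token slots the unpacked dataset occupies divided by the
# number of real tokens. The helper name below is hypothetical.
import numpy as np


def packing_speedup_upper_bound(histogram, max_sequence_length):
    """Best-case speed-up of packing over one-sequence-per-sample batching."""
    histogram = np.asarray(histogram, dtype=np.float64)
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    total_slots = histogram.sum() * max_sequence_length
    real_tokens = (histogram * sequence_lengths).sum()
    return total_slots / real_tokens


# Usage: two length-2 sequences and one length-4 sequence with L = 4 occupy
# 12 slots but hold only 8 real tokens, so the bound is 12 / 8 = 1.5.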


================================================
FILE: nit/data/pack/spfhp.py
================================================
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/spfhp.py
"""Shortest-pack-first histogram-packing."""
from collections import defaultdict
import numpy as np
import time


def add_pack(pack, count, tmp, final, limit, offset):
    """Filter out packs that reached maximum length or number of sequences."""
    if len(pack) == limit or offset == 0:
        final[offset].append((count, pack))
    else:
        tmp[offset].append((count, pack))


def SPFHP(histogram, max_sequence_length, max_sequences_per_pack):
    """Shortest-pack-first histogram-packing."""
    start = time.time()
    reversed_histogram = np.flip(histogram)
    # Initialize main strategy data dictionary.
    # The key indicates how many tokens are left for full length.
    # The value is a list of tuples, consisting of counts and respective packs.
    # A pack is a (sorted) list of sequence length values that get concatenated.
    tmp_strategies_per_length = defaultdict(list)
    strategies_per_length = defaultdict(list)
    # Because the histogram is reversed, index i holds sequences of length
    # max_sequence_length - i, i.e. a fresh pack started with them has i tokens of free space.
    for i in range(max_sequence_length):
        n_sequences_to_bin = reversed_histogram[i]
        length_to_bin = max_sequence_length - i
        offset = i + 1  # largest possible offset
        while n_sequences_to_bin > 0:
            if (length_to_bin + offset) in tmp_strategies_per_length:
                # extract shortest pack that will get modified
                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()
                new_pack = pack + [length_to_bin]
                count = min(n_sequences_to_pack, n_sequences_to_bin)
                if n_sequences_to_pack > n_sequences_to_bin:
                    # old pack gets reduced
                    n_sequences_to_pack -= n_sequences_to_bin
                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))
                    n_sequences_to_bin = 0
                else:
                    n_sequences_to_bin -= n_sequences_to_pack
                add_pack(
                    new_pack,
                    count,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    offset,
                )
                # clean up to speed up main key search
                if not tmp_strategies_per_length[length_to_bin + offset]:
                    tmp_strategies_per_length.pop(length_to_bin + offset)
            else:
                offset -= 1
            # Does not fit anywhere. Create new pack.
            if offset < 0:
                add_pack(
                    [length_to_bin],
                    n_sequences_to_bin,
                    tmp_strategies_per_length,
                    strategies_per_length,
                    max_sequences_per_pack,
                    i,
                )
                n_sequences_to_bin = 0
    # merge all strategies
    for key in tmp_strategies_per_length:
        strategies_per_length[key].extend(tmp_strategies_per_length[key])
    # flatten strategies dictionary
    strategy_set = []
    strategy_repeat_count = []
    for key in strategies_per_length:
        for count, pack in strategies_per_length[key]:
            pack.reverse()
            strategy_set.append(pack)
            strategy_repeat_count.append(count)

    # Summarize efficiency of solution
    duration = time.time() - start
    sequence_lengths = np.arange(1, max_sequence_length + 1)
    strategy_repeat_count = np.array(strategy_repeat_count)
    n_strategies = len(strategy_set)
    old_number_of_samples = histogram.sum()
    new_number_of_samples = strategy_repeat_count.sum()
    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])
    total_tokens = max_sequence_length * new_number_of_samples
    empty_tokens = sum(
        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]
    )
    efficiency = 100 - empty_tokens / total_tokens * 100
    speedup_upper_bound = 1.0 / (
        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples
    )

    print(
        f"Packing efficiency (percentage of real tokens): {efficiency:3.4f}\n",
        f"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\n",
        f"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\n",
        f"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.",
    )

    return strategy_set, np.array(strategy_repeat_count)
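

# Sanity-check sketch (illustrative, not part of the original Graphcore code):
# SPFHP bins every sequence exactly once, so a valid result must (a) keep every
# pack within max_sequence_length tokens and max_sequences_per_pack sequences,
# and (b) reproduce the input histogram when the packs are expanded by their
# repeat counts. The helper name `is_valid_packing` is hypothetical.
import numpy as np


def is_valid_packing(histogram, strategy_set, strategy_repeat_count,
                     max_sequence_length, max_sequences_per_pack):
    counts = np.zeros(max_sequence_length, dtype=np.int64)
    for pack, repeat in zip(strategy_set, strategy_repeat_count):
        if sum(pack) > max_sequence_length or len(pack) > max_sequences_per_pack:
            return False
        for length in pack:
            # Bin i of the histogram holds sequences of length i + 1.
            counts[length - 1] += repeat
    return bool(np.array_equal(counts, np.asarray(histogram)))


# Usage: with L = 4, packing the two length-2 sequences together and the
# length-4 sequence alone is valid:
#   is_valid_packing([0, 2, 0, 1], [[2, 2], [4]], [1, 1], 4, 2)  -> True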


================================================
FILE: nit/data/packed_c2i_data.py
================================================
import os
import datetime
import torchvision
import numpy as np
import torch
import ast
import json
import time


from omegaconf import OmegaConf
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
from accelerate.logging import get_logger
from safetensors.torch import load_file
from einops import rearrange
from functools import partial
from torchvision.transforms.functional import hflip

from .sampler_util import get_train_sampler, get_packed_batch_sampler

logger = get_logger(__name__, log_level="INFO")

PATCH_SIZE = 1

def resize_arr(pil_image, height, width):
    pil_image = pil_image.resize((width, height), resample=Image.Resampling.BICUBIC)

    return pil_image

def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
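

# Worked example (illustrative sketch): the geometry that center_crop_arr
# applies, computed without touching pixels. PIL sizes are (width, height).
# `center_crop_geometry` is a hypothetical helper that mirrors the three steps
# of center_crop_arr: 2x BOX halving while the short side is >= 2 * image_size,
# a bicubic resize that makes the short side exactly image_size, then a
# center crop (numpy arrays are indexed [y, x]).
def center_crop_geometry(width, height, image_size):
    while min(width, height) >= 2 * image_size:
        width, height = width // 2, height // 2
    scale = image_size / min(width, height)
    width, height = round(width * scale), round(height * scale)
    crop_x = (width - image_size) // 2
    crop_y = (height - image_size) // 2
    return (width, height), (crop_x, crop_y)


# Usage: a 1000x600 image with image_size=256 is halved once to 500x300,
# resized to 427x256, and cropped starting at x=85, y=0.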

def packed_collate_fn(batch):
    packed_latent = []
    label = []
    hw_list = []
    image = []
    for data in batch:
        C, H, W = data['latent'].shape
        latent = rearrange(
            data['latent'], 'c (h p1) (w p2) -> (h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE
        )
        packed_latent.append(latent)
        hw_list.append([H // PATCH_SIZE, W // PATCH_SIZE])
        label.append(data['label'])
        image.append(data['image'])
    packed_latent = torch.concat(packed_latent)
    label = torch.tensor(label)
    hw_list = torch.tensor(hw_list, dtype=torch.int32)
    return dict(image=image, latent=packed_latent, label=label, hw_list=hw_list)
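

# Illustrative sketch (NumPy stand-in for the einops call in packed_collate_fn):
# with `c (h p1) (w p2) -> (h w) c p1 p2`, every latent becomes a run of h*w
# tokens in row-major order (w varies fastest), and the runs of all images in a
# batch are concatenated along the token axis. `hw_list` records each image's
# (h, w) grid so the model can split the packed sequence again.
# `pack_latents_np` is a hypothetical name used only for this example.
import numpy as np


def pack_latents_np(latents, patch_size=1):
    tokens, hw_list = [], []
    for latent in latents:
        c, height, width = latent.shape
        h, w = height // patch_size, width // patch_size
        # (c, h, p1, w, p2) -> (h, w, c, p1, p2) -> (h*w, c, p1, p2)
        patches = latent.reshape(c, h, patch_size, w, patch_size).transpose(1, 3, 0, 2, 4)
        tokens.append(patches.reshape(h * w, c, patch_size, patch_size))
        hw_list.append([h, w])
    return np.concatenate(tokens, axis=0), np.array(hw_list, dtype=np.int32)


# Usage: a 4x2x3 and a 4x4x4 latent pack into 6 + 16 = 22 tokens of shape
# (4, 1, 1) with hw_list [[2, 3], [4, 4]].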



class ImprovedPackedImageNetLatentDataset(Dataset):
    def __init__(self, packed_json, jsonl_dir, data_types, latent_dirs, image_dir):
        super().__init__()
        assert len(data_types) == len(latent_dirs)
        self.type_to_dir = dict()
        for i, data_type in enumerate(data_types):
            self.type_to_dir[data_type] = latent_dirs[i]
        self.image_dir = image_dir

        with open(packed_json, 'r') as fp:
            self.packed_dataset = json.load(fp)

        with open(jsonl_dir, 'r') as fp:
            self.dataset = [json.loads(line) for line in fp]
        
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data_meta = self.dataset[index]
        
        data_item = dict()
        data_type = data_meta['type']
        latent_file = os.path.join(self.type_to_dir[data_type], data_meta['latent_file'])
        image_file = os.path.join(self.image_dir, data_meta['image_file'])

        data = load_file(latent_file)
        
        height = data_meta['latent_h'] * 16
        width = data_meta['latent_w'] * 16
        
        if data_type == 'native-resolution':
            preprocess = partial(resize_arr, height=height, width=width)
        else:
            assert height == width
            preprocess = partial(center_crop_arr, image_size=height)

        transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: preprocess(pil_image=pil_image)),
            transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))),
            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])

        rand_idx = torch.randint(low=0, high=2, size=(1,)).item()
        data_item['image'] = transform(Image.open(image_file).convert("RGB"))[rand_idx]
        data_item['latent'] = data['latent'][rand_idx]
        data_item['label'] = data['label']
        return data_item


class C2ILoader():
    def __init__(self, data_config):
        super().__init__()

        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers

        self.data_type = data_config.data_type
        
    
        if data_config.data_type == 'improved_pack':
            self.train_dataset = ImprovedPackedImageNetLatentDataset(
                **OmegaConf.to_container(data_config.dataset)
            )
        else:
            raise NotImplementedError
        
        
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        return len(self.train_dataset)

    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
        if self.data_type == 'improved_pack':
            # Each step, every rank receives one pre-computed pack (a list of dataset indices).
            batch_sampler = get_packed_batch_sampler(
                self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed
            )
            return DataLoader(
                self.train_dataset,
                batch_sampler=batch_sampler,
                collate_fn=packed_collate_fn,
                num_workers=self.num_workers,
                pin_memory=True,
            )
        else:
            # The per-sample sampler is only needed for fixed-size batches.
            sampler = get_train_sampler(
                self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed
            )
            return DataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=sampler,
                num_workers=self.num_workers,
                pin_memory=True,
                drop_last=True,
            )

    def test_dataloader(self):
        return None

    def val_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True
        )






================================================
FILE: nit/data/sampler_util.py
================================================
import torch
import json

# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60
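def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps,
                      resume_step, seed):
    """Return this rank's sample indices for steps resume_step .. max_steps - 1.

    Each epoch draws a fresh permutation seeded with seed + epoch and deals it
    round-robin across ranks; every step consumes global_batch_size // world_size
    indices per rank.
    """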
    sample_indices = torch.empty([max_steps * global_batch_size // world_size],
                                 dtype=torch.long)
    epoch_id, fill_ptr, offs = 0, 0, 0
    while fill_ptr < sample_indices.size(0):
        g = torch.Generator()
        g.manual_seed(seed + epoch_id)
        epoch_sample_indices = torch.randperm(len(dataset), generator=g)
        epoch_id += 1
        epoch_sample_indices = epoch_sample_indices[
            (rank + offs) % world_size::world_size
        ]
        offs = (offs + world_size - len(dataset) % world_size) % world_size
        epoch_sample_indices = epoch_sample_indices[
            :sample_indices.size(0) - fill_ptr
        ]
        sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \
            epoch_sample_indices
        fill_ptr += epoch_sample_indices.size(0)
    return sample_indices[resume_step * global_batch_size // world_size:].tolist()
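

# Worked example (illustrative sketch): how get_train_sampler splits one
# shuffled epoch across ranks. Rank r takes every world_size-th element
# starting at (r + offs) % world_size, so the ranks' slices are disjoint and
# together cover the epoch. When len(dataset) is not divisible by world_size,
# `offs` continues the global round-robin into the next epoch, so the ranks
# that receive the leftover samples rotate instead of always being the lowest
# ranks. The helper names below are hypothetical.
def shard_positions(dataset_len, rank, world_size, offs):
    """Positions in the epoch permutation assigned to `rank`."""
    return list(range((rank + offs) % world_size, dataset_len, world_size))


def next_shard_offset(dataset_len, world_size, offs):
    """Offset used for the next epoch, mirroring the update in get_train_sampler."""
    return (offs + world_size - dataset_len % world_size) % world_size


# Usage: with 10 samples and 4 ranks, epoch 0 gives ranks 0 and 1 three
# samples each; the offset becomes 2, so in epoch 1 ranks 2 and 3 get three.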




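def get_packed_batch_sampler(
        dataset, rank, world_size, max_steps, resume_step, seed
    ):
    """Return this rank's packs for steps resume_step .. max_steps - 1.

    `dataset` is the list of packs loaded from the packed JSON, each pack being
    a list of sample indices. Packs are shuffled per epoch (seed + epoch) and
    dealt round-robin across ranks; each step yields exactly one pack, which the
    DataLoader uses as a batch.
    """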
    sample_indices = [None for _ in range(max_steps)]
    epoch_id, fill_ptr, offs = 0, 0, 0
    while fill_ptr < len(sample_indices):
        g = torch.Generator()
        g.manual_seed(seed + epoch_id)
        epoch_sample_indices = torch.randperm(len(dataset), generator=g)
        epoch_id += 1
        epoch_sample_indices = epoch_sample_indices[
            (rank + offs) % world_size::world_size
        ]
        offs = (offs + world_size - len(dataset) % world_size) % world_size
        epoch_sample_indices = epoch_sample_indices[
            :len(sample_indices) - fill_ptr
        ]
        sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = [
            dataset[i] for i in epoch_sample_indices
        ]
        fill_ptr += epoch_sample_indices.size(0)
    return sample_indices[resume_step:]



================================================
FILE: nit/models/c2i/nit_model.py
================================================
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Mlp
from einops import rearrange, repeat
from flash_attn import flash_attn_varlen_func
from nit.models.utils.funcs import get_parameter_dtype
from nit.models.utils.pos_embeds.rope import VisionRotaryEmbedding, rotate_half
from typing import Optional

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

def build_mlp(hidden_size, projector_dim, z_dim):
    return nn.Sequential(
                nn.Linear(hidden_size, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, projector_dim),
                nn.SiLU(),
                nn.Linear(projector_dim, z_dim),
            )
#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################            
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
    
    @staticmethod
    def positional_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.positional_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
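

# Illustrative sketch: a NumPy transcription of
# TimestepEmbedder.positional_embedding, useful for checking shapes. Frequency
# k is max_period ** (-k / half) for k in [0, half); the output concatenates
# the cos and sin terms and appends one zero column when dim is odd.
# `sinusoidal_embedding_np` is a hypothetical name.
import math
import numpy as np


def sinusoidal_embedding_np(t, dim, max_period=10000):
    half = dim // 2
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = np.asarray(t, dtype=np.float32)[:, None] * freqs[None]   # (N, half)
    embedding = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:
        embedding = np.concatenate([embedding, np.zeros_like(embedding[:, :1])], axis=-1)
    return embedding


# Usage: sinusoidal_embedding_np([0.0, 0.5, 1.0], 7) has shape (3, 7); at
# t = 0 the cos half is all ones and the sin half is all zeros.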


class LabelEmbedder(nn.Module):
    """
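    Embeds class labels into vector representations. When dropout_prob > 0, the table gets one extra
    row (index num_classes) that serves as the null label for classifier-free guidance; forward() does
    not drop labels itself.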
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings
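

# Illustrative sketch: LabelEmbedder.forward only looks labels up. When
# dropout_prob > 0 the table has one extra row, index num_classes, that can
# serve as the "null" class for classifier-free guidance, so any replacement of
# labels with that index has to happen outside this module. The helper below is
# a hypothetical NumPy version of that step, in the spirit of DiT's
# `token_drop`; it is not taken from this repository.
import numpy as np


def drop_labels(labels, dropout_prob, num_classes, rng):
    """Replace each label with the null index num_classes with probability dropout_prob."""
    labels = np.asarray(labels)
    drop = rng.random(labels.shape[0]) < dropout_prob
    return np.where(drop, num_classes, labels)


# Usage: drop_labels(labels, 0.1, 1000, np.random.default_rng(0)) maps roughly
# 10% of the labels to 1000, the extra embedding row.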


#################################################################################
#                                 Attention Block                               #
#################################################################################

class Attention(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

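    def forward(self, x: torch.Tensor, cu_seqlens, freqs_cos, freqs_sin) -> torch.Tensor:
        """
        x: (N, C) tokens of all images in the pack, concatenated along N.
        cu_seqlens: (B + 1,) int32 cumulative token counts; image i spans
            x[cu_seqlens[i]:cu_seqlens[i + 1]].
        freqs_cos, freqs_sin: (N, 1, head_dim) rotary tables, broadcast over heads.
        """
        N, C = x.shape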
        qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
        ori_dtype = qkv.dtype
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        
        q = q * freqs_cos + rotate_half(q) * freqs_sin
        k = k * freqs_cos + rotate_half(k) * freqs_sin
        q, k = q.to(ori_dtype), k.to(ori_dtype)
        
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
        ).reshape(N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
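

# Reference sketch (illustrative, NumPy; not used by the model): what the
# flash_attn_varlen_func call in Attention.forward computes for a packed batch.
# All images share one (N, num_heads, head_dim) tensor, and cu_seqlens marks
# the boundaries, so image i occupies rows cu_seqlens[i]:cu_seqlens[i + 1].
# Attention is non-causal and restricted to each segment, i.e. a block-diagonal
# mask, with softmax scale 1 / sqrt(head_dim) (flash-attn's default when
# softmax_scale is not passed). The loop is far slower than the fused kernel
# and is only meant for checking small inputs; both helper names are
# hypothetical.
import numpy as np


def cu_seqlens_from_hw(hw_list):
    """Cumulative token offsets [0, h0*w0, h0*w0 + h1*w1, ...] as int32."""
    hw = np.asarray(hw_list, dtype=np.int64)
    seqlens = hw[:, 0] * hw[:, 1]
    return np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(seqlens)]).astype(np.int32)


def varlen_attention_reference(q, k, v, cu_seqlens):
    """q, k, v: (N, num_heads, head_dim); returns (N, num_heads, head_dim)."""
    out = np.empty_like(q)
    scale = q.shape[-1] ** -0.5
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        scores = np.einsum('qhd,khd->hqk', q[start:end], k[start:end]) * scale
        scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
        probs = np.exp(scores)
        probs /= probs.sum(axis=-1, keepdims=True)
        out[start:end] = np.einsum('hqk,khd->qhd', probs, v[start:end])
    return out


# Usage: hw_list [[2, 3], [4, 4]] gives cu_seqlens [0, 6, 22]; changing the
# second image's values leaves the first image's outputs untouched.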



#################################################################################
#                                 Core NiT Model                                #
#################################################################################

class NiTBlock(nn.Module):
    """
    A NiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
            )
        use_adaln_lora = block_kwargs.get('use_adaln_lora', False)
        if use_adaln_lora:
            adaln_lora_dim = block_kwargs['adaln_lora_dim']
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, adaln_lora_dim, bias=True),
                nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True)
            )
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(hidden_size, 6 * hidden_size, bias=True)
            )

    def forward(self, x, c, cu_seqlens, freqs_cos, freqs_sin):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))

        return x
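

# Illustrative sketch of adaLN-Zero at initialization. initialize_weights
# zeroes the last Linear of every adaLN_modulation branch, so shift, scale and
# gate are all 0 for any conditioning c, and each residual update
# x + gate * f(modulate(norm(x), shift, scale)) leaves x unchanged: every
# NiTBlock starts as the identity map. The NumPy stand-ins below model one of
# the block's two branches (a random "sublayer" and explicit modulation
# weights); the helper names are hypothetical.
import numpy as np


def layer_norm_np(x, eps=1e-6):
    # LayerNorm without affine parameters, like norm1/norm2 in NiTBlock.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def adaln_residual_np(x, c, mod_weight, mod_bias, sublayer):
    # SiLU -> Linear -> (shift, scale, gate), then the gated residual update.
    silu = c / (1.0 + np.exp(-c))
    shift, scale, gate = np.split(silu @ mod_weight + mod_bias, 3, axis=-1)
    h = layer_norm_np(x) * (1 + scale) + shift
    return x + gate * sublayer(h)


# Usage: with zero modulation weights and bias the output equals x exactly;
# once the modulation weights move away from zero, the branch contributes.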


class FinalLayer(nn.Module):
    """
    The final layer of NiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)

        return x


class NiT(nn.Module):
    """
    Diffusion transformer that processes packed, variable-resolution token sequences (one sequence per image).
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        encoder_depth=4,
        projector_dim=2048,
        z_dim=768,
        use_checkpoint: bool = False,
        custom_freqs: str = 'normal',
        theta: int = 10000,
        max_pe_len_h: Optional[int] = None,
        max_pe_len_w: Optional[int] = None,
        decouple: bool = False,
        ori_max_pe_len: Optional[int] = None,
        **block_kwargs # fused_attn
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.encoder_depth = encoder_depth
        self.use_checkpoint = use_checkpoint
        
        self.x_embedder = PatchEmbed(
            input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False
        )
        self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        self.rope = VisionRotaryEmbedding(
            head_dim=hidden_size//num_heads, custom_freqs=custom_freqs, theta=theta,
            max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple,
            ori_max_pe_len=ori_max_pe_len
        )

        self.projector = build_mlp(hidden_size, projector_dim, z_dim) 
        
        self.blocks = nn.ModuleList([
            NiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in NiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, patch_size=None):
        """
        x: (N, T, patch_size**2 * C) with T = h * w on a square grid (h == w)
        imgs: (N, C, h * patch_size, w * patch_size)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
    
    def get_rope(self, hw_list):
        grids = []
        for h, w in hw_list:
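            grid_h = torch.arange(h)
            grid_w = torch.arange(w)
            # Pass the width axis first with indexing='xy' (as in MAE's 2D sin-cos helper): this yields
            # (h, w) grids of column and row indices that flatten row-major, matching the '(h w)' token
            # order produced by packed_collate_fn.
            grid = torch.meshgrid(grid_w, grid_h, indexing='xy')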
            grid = torch.stack(grid, dim=0).reshape(2, -1)
            grids.append(grid)
        grids = torch.cat(grids, dim=-1)
        freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grids)
        return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)

    def forward(self, x, t, y, hw_list, return_zs=False, return_logvar=False):
        """
        Forward pass of NiT on a pack of B images with N tokens in total.
        x: (N, C, p, p) tensor of patches (images or latent representations of images), all images concatenated
        t: (B,) tensor of diffusion timesteps, one per image
        y: (B,) tensor of class labels, one per image
        hw_list: (B, 2) tensor with the (h, w) token grid of each image; N = sum(h * w)
        """
        x = self.x_embedder(x)                  # (N, C, p, p) -> (N, 1, D): each patch becomes one token
        x = x.squeeze(1)                        # (N, D)
        B = hw_list.shape[0]

        freqs_cos, freqs_sin = self.get_rope(hw_list)   # each (N, 1, D_h)
        seqlens = hw_list[:, 0] * hw_list[:, 1]
        cu_seqlens = torch.cat([
            torch.tensor([0], device=hw_list.device, dtype=torch.int), 
            torch.cumsum(seqlens, dim=0, dtype=torch.int)
        ])

        # timestep and class embedding
        t_embed = self.t_embedder(t)            # (B, D)
        y = self.y_embedder(y)                  # (B, D)
        c = t_embed + y                         # (B, D)
        
        # (B, D) -> (N, D)
        c = torch.cat([c[i].unsqueeze(0).repeat(seqlens[i], 1) for i in range(B)], dim=0)
        
        zs=[]
        for i, block in enumerate(self.blocks):
            if not self.use_checkpoint:
                x = block(x, c, cu_seqlens, freqs_cos, freqs_sin)   # (N, D)
            else:
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, c, cu_seqlens, freqs_cos, freqs_sin
                )  
            if (i + 1) == self.encoder_depth and return_zs:
                zs = [self.projector(x)]
        x = self.final_layer(x, c)              # (N, out_channels * patch_size ** 2)
        
        # (N, out_channels * patch_size ** 2) -> (N, out_channels, p, p)
        x = rearrange(x, 'n (c p1 p2) -> n c p1 p2', p1=self.patch_size, p2=self.patch_size)                  
        if return_zs:
            return x, zs
        else:
            return x  


    def ckpt_wrapper(self, module):
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs
        return ckpt_forward
    
    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)
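

# Illustrative sketch (NumPy; not used by the model) of the two RoPE pieces in
# this file. (1) Token positions: packed_collate_fn flattens each latent
# row-major, so token k of an h x w grid sits at row k // w, column k % w. With
# indexing='xy', meshgrid returns arrays shaped (len(second), len(first)), so
# the width axis has to be passed first (as in MAE's 2D sin-cos helper) for the
# flattened grid to follow the same order; np.meshgrid and torch.meshgrid agree
# for indexing='xy'. (2) Rotation: Attention.forward applies
# q * cos + rotate_half(q) * sin. The repository's rotate_half lives in
# nit.models.utils.pos_embeds.rope and is not shown here, so the split-halves
# convention below is an assumption; 2D variants typically use x-derived angles
# on some channels and y-derived angles on the rest. All helper names are
# hypothetical.
import numpy as np


def token_positions_np(h, w):
    """(2, h*w) array of (column, row) positions in row-major token order."""
    grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h), indexing='xy')  # each (h, w)
    return np.stack([grid_x, grid_y], axis=0).reshape(2, -1)


def rotate_half_np(x):
    # Assumed split-halves convention: (x1, x2) -> (-x2, x1).
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)


def apply_rope_1d_np(x, pos, theta=10000.0):
    """Rotate channel pairs (i, i + d/2) of x by angle pos * theta**(-2i/d)."""
    d = x.shape[-1]
    angles = pos * theta ** (-np.arange(0, d, 2) / d)   # (d/2,)
    cos = np.concatenate([np.cos(angles)] * 2)          # same angle for both halves of a pair
    sin = np.concatenate([np.sin(angles)] * 2)
    return x * cos + rotate_half_np(x) * sin


# Usage: token_positions_np(2, 3) gives columns [0, 1, 2, 0, 1, 2] and rows
# [0, 0, 0, 1, 1, 1]. The rotation preserves norms, and the dot product of a
# rotated query and key depends only on the difference of their positions.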



================================================
FILE: nit/models/nvidia_radio/hubconf.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

dependencies = ["torch", "timm", "einops"]

import os
from typing import Dict, Any, Optional, Union, List
import warnings

import torch
from torch.hub import load_state_dict_from_url

from timm.models import clean_state_dict

from .radio.adaptor_registry import adaptor_registry
from .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from .radio.enable_damp import configure_damp_from_args
from .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .radio.radio_model import RADIOModel, create_model_from_args
from .radio.input_conditioner import get_default_conditioner
from .radio.vitdet import apply_vitdet_arch, VitDetArgs


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Union[str, List[str]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    support_packing: bool=False,
    **kwargs,
) -> RADIOModel:
    if not version:
        version = DEFAULT_VERSION

    if os.path.isfile(version):
        chk = torch.load(version, map_location="cpu", weights_only=False)
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu", weights_only=False,
        )

    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]

    args = chk["args"]
    args.support_packing = support_packing
    mod = create_model_from_args(args)

    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")

    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    if getattr(args, 'damp', None):
        configure_damp_from_args(mod, args)

    state_dict = clean_state_dict(state_dict)

    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')

    if chk['args'].spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
        # the method is that instead of there being a `weight` tensor for certain Linear layers
        # in the model, we make it a dynamically computed function. During training, this
        # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
        # Disabling it here means that instead of recomputing `w' = f(w)` on every forward pass, we compute
        # `w' = f(w)` once, during this function call, and replace the parametrization with the realized weights.
        # This makes the model run faster, and also use less memory.
        disable_spectral_reparam(mod)
        chk['args'].spectral_reparam = False

    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))

    dtype = getattr(chk['args'], 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype

    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)
    if cls_token_per_teacher:
        name_to_idx_map = dict()
        for i, t in enumerate(chk['args'].teachers):
            if t.get('use_summary', True):
                name = t['name']
                if name not in name_to_idx_map:
                    name_to_idx_map[name] = i
        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)
    else:
        summary_idxs = torch.tensor([0], dtype=torch.int64)

    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]

    teachers = chk["args"].teachers
    adaptors = dict()
    for adaptor_name in adaptor_names:
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')

        ttype = tconf["type"]

        pf_idx_head = f'_heads.{tidx}'
        pf_name_head = f'_heads.{adaptor_name}'
        pf_idx_feat = f'_feature_projections.{tidx}'
        pf_name_feat = f'_feature_projections.{adaptor_name}'

        adaptor_state = dict()
        for k, v in state_dict.items():
            if k.startswith(pf_idx_head):
                adaptor_state['summary' + k[len(pf_idx_head):]] = v
            elif k.startswith(pf_name_head):
                adaptor_state['summary' + k[len(pf_name_head):]] = v
            elif k.startswith(pf_idx_feat):
                adaptor_state['feature' + k[len(pf_idx_feat):]] = v
            elif k.startswith(pf_name_feat):
                adaptor_state['feature' + k[len(pf_name_feat):]] = v

        adaptor = adaptor_registry.create_adaptor(ttype, chk["args"], tconf, adaptor_state)
        adaptor.head_idx = tidx if cls_token_per_teacher else 0
        adaptors[adaptor_name] = adaptor

    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )

    if vitdet_window_size is not None:
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk
    return radio
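The spectral-reparametrization comment in the loader above describes replacing a parametrized weight with its realized value. A minimal sketch of that idea with stock PyTorch follows. It assumes `torch.nn.utils.parametrizations.spectral_norm` as a stand-in for RADIO's own `enable_spectral_reparam`/`disable_spectral_reparam`, whose internals are not shown here, so treat it as an illustration of the mechanism rather than RADIO's implementation.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)

# Stand-in for a RADIO linear layer with a spectral parametrization (assumption).
layer = spectral_norm(nn.Linear(8, 4))
layer.eval()  # eval mode freezes the power-iteration vectors, so `weight` is deterministic
x = torch.randn(2, 8)

with torch.no_grad():
    y_param = layer(x)  # `weight` is recomputed as f(original) on every access
    # Bake the current value of f(original) into a plain Parameter and drop the parametrization.
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
    y_baked = layer(x)  # same outputs, but no per-forward recomputation

print(parametrize.is_parametrized(layer))  # False
```

After baking, the layer behaves like an ordinary `nn.Linear`, which is why the loader can drop the parametrization for inference without changing outputs.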


def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    mod_state_dict = {
        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)
    }
    return mod_state_dict
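The adaptor loop in the loader above combines prefix stripping (as in `get_prefix_state_dict`) with renaming: a teacher's head and feature projection may be keyed either by index (`_heads.0`) or by name (`_heads.<name>`), and both are mapped onto `summary`/`feature` keys. The sketch below mirrors that logic; the function name `remap_adaptor_state` and the example keys are hypothetical. One deliberate difference: it matches the prefix plus a trailing `.`, so index `1` does not accidentally capture keys for index `10`, which a bare `startswith('_heads.1')` would.

```python
from typing import Any, Dict


def remap_adaptor_state(state_dict: Dict[str, Any], tidx: int, name: str) -> Dict[str, Any]:
    """Collect one teacher's head/projection weights under 'summary.'/'feature.' keys.

    Hypothetical helper mirroring the adaptor loop; accepts index- or name-keyed checkpoints.
    """
    rules = [
        (f"_heads.{tidx}.", "summary."),
        (f"_heads.{name}.", "summary."),
        (f"_feature_projections.{tidx}.", "feature."),
        (f"_feature_projections.{name}.", "feature."),
    ]
    out: Dict[str, Any] = {}
    for k, v in state_dict.items():
        for src, dst in rules:
            if k.startswith(src):
                # Keep everything after the matched prefix, e.g. 'fc1.weight'.
                out[dst + k[len(src):]] = v
                break
    return out


sd = {
    "_heads.0.fc1.weight": "h0",
    "_heads.10.fc1.weight": "h10",
    "_feature_projections.clip.fc1.weight": "f_clip",
    "base_model.blocks.0.weight": "backbone",
}
print(remap_adaptor_state(sd, tidx=0, name="clip"))
# {'summary.fc1.weight': 'h0', 'feature.fc1.weight': 'f_clip'}
```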


================================================
FILE: nit/models/nvidia_radio/radio/__init__.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# Register the adaptors
from .adaptor_registry import adaptor_registry
from . import open_clip_adaptor
from .adaptor_base import AdaptorInput, RadioOutput, AdaptorBase

# Enable support for other model types via the timm register_model mechanism
from . import extra_timm_models
from . import extra_models
from . import vision_transformer_xpos


================================================
FILE: nit/models/nvidia_radio/radio/adaptor_base.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import NamedTuple, Optional

import torch
from torch import nn
import torch.nn.functional as F


class AdaptorInput(NamedTuple):
    images: torch.Tensor
    summary: torch.Tensor
    features: torch.Tensor
    feature_fmt: str
    patch_size: int


class RadioOutput(NamedTuple):
    summary: torch.Tensor
    features: torch.Tensor

    def to(self, *args, **kwargs):
        return RadioOutput(
            self.summary.to(*args, **kwargs) if self.summary is not None else None,
            self.features.to(*args, **kwargs) if self.features is not None else None,
        )


class AdaptorBase(nn.Module):
    def forward(self, input: AdaptorInput) -> RadioOutput:
        raise NotImplementedError("Subclasses must implement this!")


================================================
FILE: nit/models/nvidia_radio/radio/adaptor_generic.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace

import torch
from torch import nn
import torch.nn.functional as F

from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config


class GenericAdaptor(AdaptorBase):
    def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
        super().__init__()

        extra_args = dict()
        ups = None
        ups_rank = None
        if adaptor_config is not None:
            ups = adaptor_config.get('fd_upsample_factor', None)
            ups_rank = adaptor_config.get('fd_upsample_rank', None)
        elif mlp_config is not None:
            ups = mlp_config["feature"].get('upsample_factor', None)
            ups_rank = mlp_config["feature"].get('upsample_rank', None)
        if ups is not None:
            extra_args['upsample_factor'] = ups
            extra_args['upsample_rank'] = ups_rank

        if state is not None:
            spectral_heads = getattr(main_config, 'spectral_heads', False)
            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
        else:
            assert mlp_config is not None, "Config must not be None if state is None"

            self.head_mlp = create_mlp_from_config(
                main_config.mlp_version,
                mlp_config["summary"]["input_dim"],
                mlp_config["summary"]["hidden_dim"],
                mlp_config["summary"]["output_dim"],
                mlp_config["summary"]["num_inner"],
            )
            self.feat_mlp = create_mlp_from_config(
                main_config.mlp_version,
                mlp_config["feature"]["input_dim"],
                mlp_config["feature"]["hidden_dim"],
                mlp_config["feature"]["output_dim"],
                mlp_config["feature"]["num_inner"],
                **extra_args
            )

    def forward(self, input: AdaptorInput) -> RadioOutput:
        # Cast the input to the dtype of the adaptor's first parameter, then cast the result back.
        first_param = next(self.parameters())
        summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)

        if input.feature_fmt == 'NCHW':
            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
                        .permute(0, 3, 1, 2)
            )

        return RadioOutput(summary, feat)
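When `feature_fmt == 'NCHW'`, `GenericAdaptor.forward` above turns the row-major token sequence `(B, N, C)` into a spatial map `(B, C, H', W')`, where `H' = H // patch_size * upsample_factor` (likewise for `W'`). A standalone sketch of that reshape, with a hypothetical helper name and an explicit token-count check that the original does not perform:

```python
import torch


def tokens_to_nchw(feat: torch.Tensor, image_hw: tuple, patch_size: int, upsample_factor: int = 1) -> torch.Tensor:
    """Reshape row-major patch tokens (B, N, C) into a (B, C, H', W') feature map."""
    b, n, c = feat.shape
    h = image_hw[0] // patch_size * upsample_factor
    w = image_hw[1] // patch_size * upsample_factor
    if n != h * w:
        raise ValueError(f"expected {h * w} tokens for a {h}x{w} grid, got {n}")
    # (B, N, C) -> (B, H', W', C) -> (B, C, H', W')
    return feat.reshape(b, h, w, c).permute(0, 3, 1, 2)


feat = torch.arange(2 * 12 * 5, dtype=torch.float32).reshape(2, 12, 5)
out = tokens_to_nchw(feat, image_hw=(48, 64), patch_size=16)  # 3x4 patch grid
print(out.shape)  # torch.Size([2, 5, 3, 4])
```

Token `i * W' + j` lands at spatial position `(i, j)`, so the map can be fed directly to convolutional heads.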


================================================
FILE: nit/models/nvidia_radio/radio/adaptor_mlp.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import math
from typing import Dict, Optional

import torch
from torch import nn

from einops import rearrange
from timm.models.vision_transformer import Block

from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam


class MLP(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0, device: torch.device = None, **kwargs):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, device=device)
        self.norm = nn.LayerNorm(hidden_size, device=device)
        self.relu = nn.ReLU()

        inner = []
        for _ in range(num_inner):
            inner.extend([
                nn.Linear(hidden_size, hidden_size, device=device),
                nn.LayerNorm(hidden_size, device=device),
                nn.ReLU(),
            ])
        if inner:
            self.inner = nn.Sequential(*inner)
        else:
            self.inner = nn.Identity()

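        self.fc2 = nn.Linear(hidden_size, output_size, device=device)
        # v1 heads never upsample; exposing the attribute lets GenericAdaptor treat v1 and v2 heads alike.
        self.upsample_factor = 1

    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
        # `images` and `patch_size` are accepted only for call compatibility with MLP2 and are ignored here.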
        x = self.fc1(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.inner(x)
        x = self.fc2(x)
        return x


class MLP2(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int,
                 num_inner: int = 0,
                 pre_norm: bool = False, device: torch.device = None,
                 upsample_factor: int = 1,
                 upsample_rank: int = None,
                 from_config: bool = False,
                 **kwargs):
        super().__init__()

        self.pre_norm = nn.Sequential(
            nn.LayerNorm(input_size),
            nn.GELU(),
        ) if pre_norm else nn.Identity()

        self.upsample_factor = upsample_factor
        sq_ups = upsample_factor ** 2

        self._real_output_dim = output_size // sq_ups

        # hidden_size *= upsample_factor
        # output_size *= (upsample_factor ** 2)

        self.fc1 = nn.Linear(input_size, hidden_size, device=device)

        blocks = []
        for _ in range(num_inner):
            blocks.append(nn.Sequential(
                nn.LayerNorm(hidden_size, device=device),
                nn.GELU(),
                nn.Linear(hidden_size, hidden_size, device=device),
            ))
        self.blocks = nn.ModuleList(blocks)

        self.final = nn.Sequential(
            nn.LayerNorm(hidden_size, device=device),
            nn.GELU(),
            nn.Linear(hidden_size, output_size, device=device),
        )

    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
        x = self.pre_norm(x)
        x = self.fc1(x)
        for block in self.blocks:
            x = x + block(x)
        x = self.final(x)

        if self.upsample_factor > 1:
            if images is None:
                raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
            if patch_size is None:
                raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
            h, w = tuple(d // patch_size for d in images.shape[-2:])
            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                          c=self._real_output_dim)

        return x


MLP_FACTORY = {
    'v1': MLP,
    'v2': MLP2,
}


def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
    state = {
        k[len(prefix):]: v
        for k, v in state.items()
        if k.startswith(prefix)
    }
    return state


def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
    state = strip_prefix(state, prefix)

    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'

    if version == 'v1':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'fc2.{weight_suffix}'].shape[0]

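        for num_inner in range(1000):
            # `inner` is a flat Sequential of (Linear, LayerNorm, ReLU) triples, so the
            # LayerNorm of the i-th inner layer is stored as `inner.{3 * i + 1}.weight`.
            k = f'inner.{3 * num_inner + 1}.weight'
            if k not in state:
                break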
    elif version == 'v2':
        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
        output_dim = state[f'final.2.{weight_suffix}'].shape[0]

        for num_inner in range(1000):
            k = f'blocks.{num_inner}.0.weight'
            if k not in state:
                break
    else:
        raise ValueError(f'Unsupported MLP version: {version}')

    return input_dim, hidden_dim, output_dim, num_inner


def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)

    return ret


def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
    state = strip_prefix(state, prefix)

    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)

    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)

    if spectral_weights:
        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)

    ret.load_state_dict(state)

    if spectral_weights:
        disable_spectral_reparam(ret)

    return ret
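`create_mlp_from_state` above never stores the head's configuration; it reads the sizes back from parameter shapes. A sketch of the v2 case using plain shape tuples instead of tensors (the helper name and example dimensions are hypothetical): `nn.Linear` weights are `(out_features, in_features)`, and each residual block begins with a `LayerNorm` at `blocks.<i>.0`.

```python
from typing import Dict, Tuple


def infer_mlp_v2_dims(shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int, int, int, int]:
    """Recover (input_dim, hidden_dim, output_dim, num_inner) from an MLP2 head's parameter shapes."""
    hidden_dim, input_dim = shapes["fc1.weight"]  # Linear weight: (out_features, in_features)
    output_dim = shapes["final.2.weight"][0]      # last Linear inside `final`
    num_inner = 0
    while f"blocks.{num_inner}.0.weight" in shapes:  # LayerNorm opening each residual block
        num_inner += 1
    return input_dim, hidden_dim, output_dim, num_inner


shapes = {
    "fc1.weight": (1024, 768),
    "blocks.0.0.weight": (1024,),
    "blocks.1.0.weight": (1024,),
    "final.2.weight": (1280, 1024),
}
print(infer_mlp_v2_dims(shapes))  # (768, 1024, 1280, 2)
```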


================================================
FILE: nit/models/nvidia_radio/radio/adaptor_registry.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import Dict, Any

import torch

from .adaptor_generic import GenericAdaptor, AdaptorBase

dict_t = Dict[str, Any]
state_t = Dict[str, torch.Tensor]


class AdaptorRegistry:
    def __init__(self):
        self._registry = {}

    def register_adaptor(self, name):
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function
        return decorator

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        if name not in self._registry:
            return GenericAdaptor(main_config, adaptor_config, state)
        return self._registry[name](main_config, adaptor_config, state)

# Creating an instance of the registry
adaptor_registry = AdaptorRegistry()
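`AdaptorRegistry` above is a name-to-factory map filled by a decorator, with `GenericAdaptor` as the fallback for unregistered names. A minimal, dependency-free sketch of the same pattern; the class, adaptor names, and return values are hypothetical:

```python
from typing import Callable, Dict


class Registry:
    """Name -> factory map with a fallback factory for unregistered names."""

    def __init__(self, default: Callable[..., object]):
        self._registry: Dict[str, Callable[..., object]] = {}
        self._default = default

    def register(self, name: str):
        def decorator(factory):
            if name in self._registry:
                raise ValueError(f"'{name}' already registered")
            self._registry[name] = factory
            return factory  # return the function unchanged so it stays usable directly
        return decorator

    def create(self, name: str, *args, **kwargs):
        # Unknown names fall back to the default factory instead of raising.
        return self._registry.get(name, self._default)(*args, **kwargs)


registry = Registry(default=lambda cfg: ("generic", cfg))


@registry.register("open_clip")
def make_clip(cfg):
    return ("clip", cfg)


print(registry.create("open_clip", 1))  # ('clip', 1)
print(registry.create("dino", 2))       # ('generic', 2)
```

Registration happens at import time, which is why `radio/__init__.py` imports `open_clip_adaptor` purely for its side effect.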


================================================
FILE: nit/models/nvidia_radio/radio/block.py
================================================
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Block modules
"""

import torch
import torch.nn as nn
from timm.models.layers import DropPath

from .conv import Conv
# from .transformer import TransformerBlock

__all__ = ('C2f', 'Bottleneck',)

class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        if drop_path is None:
            drop_path = [0.0] * n

        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):  # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the two convolutions, adding the (drop-path) residual when `self.add` is set."""
        return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
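The channel bookkeeping in `C2f` above is easy to miss: `cv1` produces `2c` channels that are split in two, each of the `n` blocks appends another `c`-channel tensor computed from the last one, and `cv2` fuses all `(2 + n) * c` channels. A shape-level sketch, assuming plain `nn.Conv2d` layers as stand-ins for the `Conv` (conv + BN + SiLU) and `Bottleneck` modules; the class name is hypothetical:

```python
import torch
from torch import nn


class TinyC2f(nn.Module):
    """Shape-level sketch of C2f with plain convolutions standing in for Conv/Bottleneck."""

    def __init__(self, c1: int, c2: int, n: int = 1, e: float = 0.5):
        super().__init__()
        self.c = int(c2 * e)
        self.cv1 = nn.Conv2d(c1, 2 * self.c, 1)
        self.m = nn.ModuleList(nn.Conv2d(self.c, self.c, 3, padding=1) for _ in range(n))
        self.cv2 = nn.Conv2d((2 + n) * self.c, c2, 1)

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, 1))  # two c-channel halves
        for m in self.m:
            y.append(m(y[-1]))             # each block extends the list by c channels
        return self.cv2(torch.cat(y, 1))   # (2 + n) * c -> c2


block = TinyC2f(32, 64, n=3)
out = block(torch.randn(1, 32, 8, 8))
print(out.shape)  # torch.Size([1, 64, 8, 8])
```

For an even split, `chunk(2, 1)` and `split((c, c), 1)` return the same tensors, so `forward` and `forward_split` compute the same result.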


================================================
FILE: nit/models/nvidia_radio/radio/cls_token.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from typing import Optional

import torch
from torch import nn


class ClsToken(nn.Module):
    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
                self.num_registers = register_multiple - (num_tokens % register_multiple)

            scale = ndim ** -0.5
            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([
            token,
            x,
        ], dim=1)

        return x

    def no_weight_decay(self):
        return [
            'token',
        ]
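`ClsToken.__init__` above decides how many register tokens to prepend: an explicit `num_registers` wins; otherwise `register_multiple` pads the prefix so that CLS tokens plus registers form a multiple of that value. A sketch of just that rule (hypothetical helper name):

```python
from typing import Optional


def count_registers(num_tokens: int, register_multiple: Optional[int] = None,
                    num_registers: Optional[int] = None) -> int:
    """Register-count rule from ClsToken.__init__ (enabled case)."""
    if num_registers:
        return num_registers  # explicit count wins
    if register_multiple:
        # Pad CLS + registers to a multiple of `register_multiple`.
        return register_multiple - (num_tokens % register_multiple)
    return 0


for tokens in (1, 3, 8):
    regs = count_registers(tokens, register_multiple=8)
    print(tokens, regs, tokens + regs)
# 1 7 8
# 3 5 8
# 8 8 16
```

As the last row shows, the formula adds a full extra group when `num_tokens` is already a multiple. The checkpoint's `token` parameter has `num_tokens + num_registers` rows, so this count must match the one used during training.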


================================================
FILE: nit/models/nvidia_radio/radio/common.py
================================================
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from dataclasses import dataclass
from typing import Optional

from .radio_model import Resolution


@dataclass
class RadioResource:
    url: str
    patch_size: int
    max_resolution: int
    preferred_resolution: Resolution
    vitdet_num_windowed: Optional[int] = None
    vitdet_num_global: Optional[int] = None


RESOURCE_MAP = {
    # RADIOv2.5
    "radio_v2.5-b": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-l": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-h": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-h-norm": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=(768, 768),
        vitdet_num_global=4,
    ),
    "radio_v2.5-g": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true",
        patch_size=14,
        max_resolution=1792,
        preferred_resolution=(896, 896),
        vitdet_num_global=8,
    ),
    # RADIO
    "radio_v2.1": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(432, 432),
        vitdet_num_windowed=5,
    ),
    "radio_v2": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(432, 432),
        vitdet_num_windowed=5,
    ),
    "radio_v1": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true",
        patch_size=14,
        max_resolution=1050,
        preferred_resolution=Resolution(378, 378),
    ),
    # E-RADIO
    "e-radio_v2": RadioResource(
        "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(512, 512),
    ),
    # C-RADIO
    "c-radio_v2.5-g": RadioResource(
        "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=(768, 768),
        vitdet_num_global=8,
    ),
    "c-radio_v3-l": RadioResource(
        # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
        # and accept the license terms.
        "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
        patch_size=16,
        max_resolution=2048,
        preferred_resolution=Resolution(512, 512),
    ),
}

DEFAULT_VERSION = "radio_v2.5-h"


================================================
FILE: nit/models/nvidia_radio/radio/conv.py
================================================
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Convolution modules
"""

import math

import numpy as np
import torch
import torch.nn as nn

__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
           'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')


def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
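`autopad` above chooses the padding for "same" output size: with dilation `d`, the effective kernel is `d * (k - 1) + 1`, and half of it is padded on each side. The check below plugs that padding into PyTorch's documented `Conv2d` output-size formula, `floor((n + 2p - d(k - 1) - 1) / s) + 1`; the helper `conv_out_size` is hypothetical.

```python
def autopad(k, p=None, d=1):  # kernel, padding, dilation (copied from above)
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def conv_out_size(n: int, k: int, s: int = 1, p: int = 0, d: int = 1) -> int:
    """Output length of a convolution along one axis."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1


for k, d in [(1, 1), (3, 1), (3, 2), (5, 3)]:
    print(k, d, autopad(k, d=d), conv_out_size(32, k, 1, autopad(k, d=d), d))
```

At stride 1 the size is preserved for odd effective kernels; an even kernel such as `k=2` gives `p=1` and an output of 33, so "same" only holds exactly for odd kernels.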

# Pavlo's implementation with switch to deploy
class Conv(nn.Module):
    default_act = nn.SiLU()  # default activation

    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
        super().__init__()

        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()


    def forward(self,x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

    @torch.no_grad()
    def switch_to_deploy(self):
        if not isinstance(self.bn, nn.Identity):
            # return 1
            c, bn = self.conv, self.bn
            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
            w = c.weight * w[:, None, None, None]
            b = bn.bias - bn.running_mean * bn.weight / \
                (bn.running_var + bn.eps)**0.5
            # m = torch.nn.Conv2d(w.size(1) * c.groups,
            #                     w.size(0),
            #                     w.shape[2:],
            #                     stride=c.stride,
            #                     padding=c.padding,
            #                     dilation=c.dilation,
            #                     groups=c.groups)
            self.conv.weight.data.copy_(w)
            self.conv.bias = nn.Parameter(b)
            # self.conv.bias.data.copy_(b)
            # self.conv = m.to(c.weight.device)
            self.bn = nn.Identity()
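`switch_to_deploy` above folds an inference-mode BatchNorm into the preceding convolution. With `scale = gamma / sqrt(running_var + eps)`, the fused weight is `W * scale` (per output channel) and the fused bias is `beta - running_mean * scale`. The sketch below checks that identity numerically on a fresh conv/BN pair with arbitrary statistics, building a separate fused layer instead of mutating the original:

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(4, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():  # give BN non-trivial statistics and affine parameters
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()  # folding is only valid with frozen running statistics

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(4, 8, 3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])  # scale each output channel
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

    x = torch.randn(2, 4, 16, 16)
    ref = bn(conv(x))
    out = fused(x)

print(torch.allclose(ref, out, atol=1e-5))  # True
```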


================================================
FILE: nit/models/nvidia_radio/radio/dinov2_arch.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

# Nvidia
# NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking,
# but also because Huggingface does a string replace of `gamma` to something else when loading the model state,
# and this breaks loading of this model.

from enum import Enum
from functools import partial
import logging
import math
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_

_torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention')


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
    else:
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
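

# Illustrative sketch (hypothetical helper, not part of DINOv2 or RADIO): the
# patch-grid and FLOP arithmetic that PatchEmbed performs, in plain Python so it
# can be checked without building the module. It assumes the same convention as
# PatchEmbed.flops(): one multiply-accumulate per output element per kernel tap,
# plus one op per output element when a norm layer is present.
def _sketch_patch_embed_stats(img_hw, patch_hw, in_chans=3, embed_dim=768, has_norm=False):
    """Return ((grid_h, grid_w), num_patches, flops) for a PatchEmbed config."""
    grid_h = img_hw[0] // patch_hw[0]
    grid_w = img_hw[1] // patch_hw[1]
    num_patches = grid_h * grid_w
    # Conv2d with kernel == stride == patch size: each of the embed_dim output
    # channels of each token mixes in_chans * patch_h * patch_w inputs.
    flops = num_patches * embed_dim * in_chans * patch_hw[0] * patch_hw[1]
    if has_norm:
        flops += num_patches * embed_dim
    return (grid_h, grid_w), num_patches, flops


# Example: a 224x224 image with 16x16 patches gives a 14x14 grid (196 tokens);
# DINOv2's default 518x518 input with 14x14 patches gives 37x37 = 1369 tokens.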


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]
        if _torch_has_sdpa:
            x = F.scaled_dot_product_attention(
                q, k, v,
                is_causal=False,
                dropout_p=self.attn_drop.p if self.training else 0.,
                scale=self.scale,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
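

# Illustrative sketch (hypothetical helper): what the non-SDPA branch of
# Attention.forward computes for a single head, using plain Python lists. Each
# query row attends to every key: softmax(scale * q @ k^T) @ v, where scale
# defaults to head_dim ** -0.5 as in self.scale. Attention dropout is omitted,
# and the max-subtraction is the usual numerically stable softmax.
def _sketch_single_head_attention(q, k, v, scale=None):
    """q: [n_q][d], k: [n_k][d], v: [n_k][d_v] -> [n_q][d_v]."""
    import math

    if scale is None:
        scale = len(q[0]) ** -0.5
    out = []
    for q_row in q:
        logits = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(logits)
        exps = [math.exp(l - m) for l in logits]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of value rows, one output channel at a time.
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))])
    return out


# Example: two identical keys receive equal weight, so the output is the mean
# of the two value rows.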


class MemEffAttention(Attention):
    def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


if not XFORMERS_AVAILABLE:
    SwiGLU = SwiGLUFFN


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
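

# Illustrative sketch (hypothetical helpers): SwiGLUFFNFused shrinks the
# requested hidden width to 2/3 and rounds up to a multiple of 8. A gated FFN
# has three weight matrices (w12 counts as two), so 3 * (2/3) * h matches the
# 2 * h of a two-matrix Mlp and the weight counts stay roughly equal. Biases are
# ignored here.
def _sketch_swiglu_hidden_dim(hidden_features):
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8


def _sketch_ffn_weight_counts(dim, mlp_ratio=4.0):
    """Return (mlp_weights, swiglu_weights) for one FFN, excluding biases."""
    hidden = int(dim * mlp_ratio)
    mlp_weights = 2 * dim * hidden  # fc1 + fc2
    swiglu_hidden = _sketch_swiglu_hidden_dim(hidden)
    swiglu_weights = dim * 2 * swiglu_hidden + swiglu_hidden * dim  # w12 + w3
    return mlp_weights, swiglu_weights


# Example: dim=1024 with mlp_ratio=4 requests 4096 hidden units; the fused
# SwiGLU uses 2736 instead (8,404,992 weights vs. 8,388,608 for Mlp).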


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
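

# Illustrative sketch (hypothetical helper): per-sample stochastic depth on a
# list of scalar residual-branch outputs. Each sample is kept with probability
# keep_prob = 1 - drop_prob and rescaled by 1 / keep_prob, so the expected value
# matches the no-drop case, as in drop_path() above. Unlike drop_path(), this
# sketch ignores the training flag and always applies the mask.
def _sketch_drop_path(values, drop_prob, seed=0):
    import random

    if drop_prob == 0.0:
        return list(values)
    keep_prob = 1.0 - drop_prob
    rng = random.Random(seed)
    # Bernoulli(keep_prob) mask per sample; survivors are scaled up by 1 / keep_prob.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]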


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, torch.Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.grandma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.grandma) if self.inplace else x * self.grandma

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Hugging Face renames state-dict keys that contain `gamma` when loading, so the standard DINOv2
        # LayerScale parameter name does not work with HF Hub. The parameter is therefore named 'grandma',
        # and checkpoints using either name can be loaded.
        key_a = f'{prefix}gamma'
        key_b = f'{prefix}grandma'
        if key_a in state_dict:
            gamma = state_dict[key_a]
        elif key_b in state_dict:
            gamma = state_dict[key_b]
        else:
            if strict:
                raise KeyError(f"Couldn't find the key {key_a} nor {key_b} in the state dict!")
            else:
                missing_keys.append(key_a)
                missing_keys.append(key_b)
                unexpected_keys.extend(state_dict.keys())
                gamma = None

        if gamma is not None:
            self.grandma.data.copy_(gamma)

        # return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
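

# Illustrative sketch (hypothetical helper): renaming LayerScale keys offline,
# e.g. to convert an original DINOv2 checkpoint (`...ls1.gamma`) to this
# module's naming (`...ls1.grandma`). LayerScale._load_from_state_dict already
# accepts either name, so this is only useful when another tool must see the new
# names. It assumes LayerScale is the only module with a parameter whose last
# name component is exactly `gamma`.
def _sketch_rename_layerscale_keys(state_dict):
    renamed = {}
    for key, value in state_dict.items():
        prefix, sep, leaf = key.rpartition(".")
        if leaf == "gamma":
            key = f"{prefix}{sep}grandma"
        renamed[key] = value
    return renamed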


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.grandma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, torch.Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"Expected a tensor or a list of tensors, got {type(x_or_x_list)}")


def drop_add_residual_stochastic_depth(
    x: torch.Tensor,
    residual_func: Callable[[torch.Tensor], torch.Tensor],
    sample_drop_ratio: float = 0.0,
) -> torch.Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor
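

# Illustrative sketch (hypothetical helpers): the batch-subset form of
# stochastic depth used by drop_add_residual_stochastic_depth() and
# add_residual(). Instead of zeroing dropped samples, the residual branch runs
# only on a random subset of the batch, and its output is added back to those
# rows scaled by batch_size / subset_size (torch.index_add with alpha).
def _sketch_subset_size_and_scale(batch_size, sample_drop_ratio):
    subset = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    return subset, batch_size / subset


def _sketch_index_add_rows(x_rows, brange, residual_rows, alpha):
    """Pure-Python analogue of torch.index_add(x, 0, brange, residual, alpha=alpha)."""
    out = [list(row) for row in x_rows]
    for src_idx, dst_idx in enumerate(brange):
        out[dst_idx] = [a + alpha * r for a, r in zip(out[dst_idx], residual_rows[src_idx])]
    return out


# Example: with batch_size=8 and sample_drop_ratio=0.25, the branch runs on 6
# samples and their residuals are scaled by 8 / 6.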


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
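    Optionally index-select each tensor with `branges`, concatenate everything into a single
    (1, total_tokens, dim) sequence, and return the block-diagonal attn_bias for that layout (cached by shape).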
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
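

# Illustrative sketch (hypothetical helpers): how get_attn_bias_and_cat() lays
# out a nested batch. Every sample of every tensor becomes one segment of a
# single concatenated sequence, and attention is restricted to a block diagonal
# so tokens attend only within their own segment. xFormers' BlockDiagonalMask
# represents this lazily; the dense boolean mask below is only for illustration.
# The cache key above is the tuple of (batch_size, seq_len) pairs, so repeated
# shapes reuse the same mask object.
def _sketch_seqlens(shapes):
    """shapes: list of (batch_size, seq_len) pairs, one per tensor in x_list."""
    return [seq_len for batch_size, seq_len in shapes for _ in range(batch_size)]


def _sketch_block_diagonal_mask(seqlens):
    segment_of = [seg for seg, n in enumerate(seqlens) for _ in range(n)]
    # mask[i][j] is True when token i may attend to token j.
    return [[si == sj for sj in segment_of] for si in segment_of]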


def drop_add_residual_stochastic_depth_list(
    x_list: List[torch.Tensor],
    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> torch.Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
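

# Illustrative sketch (hypothetical helper): the visiting order of named_apply()
# on a toy module tree given as nested dicts ({child_name: subtree}). With
# depth_first=True a child is visited after its own children (post-order); with
# depth_first=False it is visited before them (pre-order). The root is visited
# only when include_root=True.
def _sketch_named_apply_order(tree, name="", depth_first=True, include_root=False):
    order = []
    if not depth_first and include_root:
        order.append(name)
    for child_name, subtree in tree.items():
        full_name = ".".join((name, child_name)) if name else child_name
        order.extend(_sketch_named_apply_order(subtree, full_name, depth_first, include_root=True))
    if depth_first and include_root:
        order.append(name)
    return order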


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x
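

# Illustrative sketch (hypothetical helpers): how DinoVisionTransformer builds
# its block list. Drop-path rates grow linearly from 0 to drop_path_rate (the
# "stochastic depth decay rule"), and with block_chunks > 0 each BlockChunk is
# left-padded with nn.Identity placeholders so a block keeps the same global
# index regardless of chunking. Below, "I" stands for an Identity placeholder and
# integers for real block indices. When depth is not divisible by block_chunks,
# more than block_chunks chunks are produced.
def _sketch_drop_path_rates(drop_path_rate, depth, uniform=False):
    if uniform:
        return [drop_path_rate] * depth
    if depth == 1:
        return [0.0]  # torch.linspace with steps=1 returns only the start value
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


def _sketch_chunk_layout(depth, block_chunks):
    chunksize = depth // block_chunks
    return [["I"] * i + list(range(i, min(i + chunksize, depth))) for i in range(0, depth, chunksize)]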


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values; None or 0 disables LayerScale
            embed_layer (nn.Module): patch embedding layer class
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks (int): split block sequence into block_chunks units for FSDP wrap; 0 disables chunking
            num_register_tokens (int): number of extra cls tokens (so-called "registers")
            interpolate_antialias (bool): flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset (float): work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
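

# Illustrative sketch (hypothetical helper): the target grid computed by
# DinoVisionTransformer.interpolate_pos_encoding(). The pretrained patch
# position embeddings form an M x M grid (M = sqrt(num_pos_tokens - 1), the
# extra token being the CLS embedding). For an input of height H and width W
# (named `w` and `h` in the code, which unpacks x.shape as B, nc, w, h), the grid
# is resized to (H // patch, W // patch). With a nonzero interpolate_offset the
# resize is passed as a scale factor (grid + offset) / M; interpolate() then
# floors M * scale, and the small offset keeps floating-point error from
# landing one cell short.
def _sketch_pos_embed_grid(num_pos_tokens, img_hw, patch_size, interpolate_offset=0.1):
    import math

    n = num_pos_tokens - 1
    m = int(math.sqrt(n))
    assert m * m == n, "pretrained position embeddings must form a square grid"
    grid_h, grid_w = img_hw[0] // patch_size, img_hw[1] // patch_size
    if interpolate_offset:
        scale_h = (grid_h + interpolate_offset) / m
        scale_w = (grid_w + interpolate_offset) / m
        # Output size derived from a scale factor: floor(input_size * scale).
        out = (math.floor(m * scale_h), math.floor(m * scale_w))
    else:
        out = (grid_h, grid_w)
    return m, out


# Example: a 37x37 pretrained grid (518 / 14) resized for a 448x644 input with
# 14x14 patches becomes a 32x46 grid.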


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
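

# Illustrative sketch (hypothetical helper): the four presets above all keep the
# per-head width at 64 (embed_dim / num_heads). A rough weight count for the
# transformer blocks with the default Mlp FFN is 12 * embed_dim**2 per block
# (4 d^2 for qkv + proj, 8 d^2 for fc1 + fc2 at mlp_ratio=4), ignoring biases,
# norms, LayerScale and embeddings. The dinov2_vitg14 entry points swap in
# SwiGLUFFNFused, which this estimate does not model.
_SKETCH_VIT_PRESETS = {
    "vit_small": (384, 12, 6),  # (embed_dim, depth, num_heads)
    "vit_base": (768, 12, 12),
    "vit_large": (1024, 24, 16),
    "vit_giant2": (1536, 40, 24),
}


def _sketch_preset_summary(arch_name):
    embed_dim, depth, num_heads = _SKETCH_VIT_PRESETS[arch_name]
    head_dim = embed_dim // num_heads
    approx_block_weights = 12 * embed_dim * embed_dim * depth
    return head_dim, approx_block_weights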


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
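    vit_kwargs.update(**kwargs)
    # `weights` is only validated above: no pretrained checkpoint is downloaded or
    # loaded here, so the returned model has freshly initialized parameters.
    model = sys.modules[__name__].__dict__[arch_name](**vit_kwargs)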

    return model


def dinov2_vits14(**kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", **kwargs)


def dinov2_vitb14(**kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", **kwargs)


def dinov2_vitl14(**kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", **kwargs)


def dinov2_vitg14(**kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        **kwargs,
    )


def dinov2_vits14_reg(**kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(**kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(**kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(**kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
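

# Illustrative sketch (hypothetical helper): the effective arguments each hub
# entry point above forwards to _make_dinov2_model(), merged over that
# function's defaults. The "_reg" variants add 4 register tokens and switch to
# antialiased interpolation with interpolate_offset=0.0; the ViT-g variants use
# the fused SwiGLU FFN. Other caller kwargs replace the defaults, but passing a
# key that an entry point already fixes (e.g. ffn_layer for ViT-g) raises a
# TypeError for a duplicate keyword argument.
_SKETCH_DINOV2_DEFAULTS = dict(
    img_size=518, patch_size=14, init_values=1.0, ffn_layer="mlp", block_chunks=0,
    num_register_tokens=0, interpolate_antialias=False, interpolate_offset=0.1,
)
_SKETCH_REG_OVERRIDES = dict(num_register_tokens=4, interpolate_antialias=True, interpolate_offset=0.0)


def _sketch_resolve_dinov2_entry(size, registers=False):
    """size: one of "s", "b", "l", "g"; returns (arch_name, resolved kwargs)."""
    arch_name = {"s": "vit_small", "b": "vit_base", "l": "vit_large", "g": "vit_giant2"}[size]
    resolved = dict(_SKETCH_DINOV2_DEFAULTS)
    if size == "g":
        resolved["ffn_layer"] = "swiglufused"
    if registers:
        resolved.update(_SKETCH_REG_OVERRIDES)
    return arch_name, resolved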


================================================
FILE: nit/models/nvidia_radio/radio/dual_hybrid_vit.py
================================================
from logging import getLogger
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from timm.models import register_model
from timm.models import vision_transformer as tvit
from timm.models import convnext as tconv

from einops import rearrange

from . import extra_timm_models as et


class Fuser(nn.Module):
    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
        super().__init__()
        self.gated = gated

        mid_dim = max(src_dim, tgt_dim) * 2

        self.fwd = nn.Sequential(
            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
        )

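    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        # The 4D input (B, C, H, W) defines the spatial grid; 3D inputs are
        # token sequences (B, prefix + H*W, C).
        if src.ndim == 3:
            shape = tgt.shape[-2:]
        else:
            shape = src.shape[-2:]

        nd = shape[0] * shape[1]

        # Keep the trailing H*W patch tokens and move channels first. rearrange is
        # needed here: a plain reshape of (B, H*W, C) to (B, C, H, W) scrambles features.
        if src.ndim == 3:
            src = rearrange(src[:, -nd:], 'b (h w) c -> b c h w', h=shape[0], w=shape[1])

        if tgt.ndim == 3:
            tgt_pre = tgt[:, :-nd]
            tgt = rearrange(tgt[:, -nd:], 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
        else:
            tgt_pre = None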

        pred = self.fwd(src)

        if self.gated:
            g, pred = torch.chunk(pred, 2, dim=1)

            g = F.sigmoid(g)

            pred = g * pred

        tgt = tgt + pred

        if tgt_pre is not None:
            tgt = rearrange(tgt, 'b c h w -> b (h w) c')
            tgt = torch.cat([tgt_pre, tgt], dim=1)

        return tgt


class AttnDownsample(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int = 16):
        super().__init__()
        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
        ntok = twod_shape[0] * twod_shape[1]
        x_pre = x[:, :-ntok]

        B = x.shape[0]
        ds_hw = tuple(s // self.window_size for s in twod_shape)

        x_spat = rearrange(
            x[:, -ntok:],
            'b (h d1 w d2) c -> (b h w) (d1 d2) c',
            h=ds_hw[0], w=ds_hw[1],
            d1=self.window_size, d2=self.window_size,
        )

        B, N, C = x_spat.shape

        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

        q = (self.q * self.scale).expand(B, -1, -1, -1)
        attn = q @ k.transpose(-2, -1)
        attn = F.softmax(attn, dim=-1)
        x = attn @ v

        x = x.transpose(1, 2).reshape(B, C)
        x = self.proj(x)

        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])

        x = torch.cat([x_pre, x], dim=1)
        return x


class HybridModel(nn.Module):
    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
                 concatenate: bool = False, **kwargs):
        super().__init__()
        self.conv = conv
        self.vit = vit
        self.concatenate = concatenate

        conv.stages = nn.ModuleList(conv.stages)
        vit.blocks = nn.ModuleList(vit.blocks)

        self._half_vit_idx = len(vit.blocks) // 2 + 1

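        # Probe the ConvNeXt stages with a dummy 256x256 input to find the first
        # stage whose output is 16x16, which is the ViT/16 token grid at that resolution.
        self._half_conv_idx = None
        with torch.no_grad():
            x = torch.zeros(1, 3, 256, 256)
            x = self.conv.stem(x)
            for i in range(len(conv.stages)):
                x = conv.stages[i](x)
                if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
                    self._half_conv_idx = i + 1
                    half_conv_dim = x.shape[1]
                final_conv_dim = x.shape[1]
        if self._half_conv_idx is None:
            raise ValueError('No ConvNeXt stage produces a 16x16 feature map for a 256x256 input')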

        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)

        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
        if not concatenate:
            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
        self.final_block = tvit.Block(embed_dim, num_heads=16)

        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        return 32

    @property
    def no_fsdp_wrap_types(self):
        return {tvit.VisionTransformer, tconv.ConvNeXt}

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.forward_features(x)

    def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        y_vit = self.vit.patch_generator(x)

        for i in range(self._half_vit_idx):
            y_vit = self.vit.blocks[i](y_vit)

        y_conv = self.conv.stem(x)
        for i in range(self._half_conv_idx):
            y_conv = self.conv.stages[i](y_conv)

        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)

        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])

        for i in range(self._half_vit_idx, len(self.vit.blocks)):
            y_vit = self.vit.blocks[i](y_vit)

        for i in range(self._half_conv_idx, len(self.conv.stages)):
            y_conv = self.conv.stages[i](y_conv)

        if self.concatenate:
            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
            # Average pool across the board, and replicate for each cls/register token
            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
            y_conv = torch.cat([conv_summary, y_conv], dim=1)
            y = torch.cat([y_vit, y_conv], dim=2)
        else:
            y = self.final_fuse(y_conv, y_vit)
        y = self.final_block(y)

        summary = y[:, :self.vit.patch_generator.num_cls_tokens]
        features = y[:, self.vit.patch_generator.num_cls_patches:]

        return summary, features


@register_model
def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
    vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)


@register_model
def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
    vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)


@register_model
def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
    vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
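`AttnDownsample` reduces the number of patch tokens by a factor of `window_size**2`. The einops pattern `'b (h d1 w d2) c -> (b h w) (d1 d2) c'` collects each `window_size x window_size` block of the row-major token grid into one group. A single learned query then attends over that group and produces one output token. The prefix (CLS/register) tokens in `x_pre` bypass this step unchanged. The stdlib-only sketch below uses a hypothetical helper, `window_groups`, to list the flat token indices that end up in each group for a single image, in the same order the pattern produces.

```python
from typing import List


def window_groups(grid_h: int, grid_w: int, window: int) -> List[List[int]]:
    """Flat (row-major) token indices gathered into each window, in the order
    produced by 'b (h d1 w d2) c -> (b h w) (d1 d2) c' for a single image."""
    if grid_h % window or grid_w % window:
        raise ValueError("token grid must be divisible by the window size")
    out_h, out_w = grid_h // window, grid_w // window
    groups = []
    for hi in range(out_h):          # einops axis h (outermost)
        for wi in range(out_w):      # einops axis w
            groups.append([
                (hi * window + d1) * grid_w + (wi * window + d2)
                for d1 in range(window)   # einops axis d1
                for d2 in range(window)   # einops axis d2 (fastest)
            ])
    return groups


for group in window_groups(4, 4, 2):
    print(group)
# [0, 1, 4, 5]
# [2, 3, 6, 7]
# [8, 9, 12, 13]
# [10, 11, 14, 15]
```

Each group holds spatially adjacent tokens, so attention pooling acts like a learned 2x2 pooling. With the 16x16 grid used at the probe resolution, `window_size=2` leaves 64 patch tokens.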


================================================
FILE: nit/models/nvidia_radio/radio/enable_cpe_support.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from typing import List, Optional, Set, Tuple, Union
from types import MethodType

import torch
from torch import nn

from timm.models import VisionTransformer, checkpoint_seq
from timm.models.vision_transformer import Attention, Block

from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer

from .extra_models import DinoWrapper
from .vit_patch_generator import ViTPatchGenerator
from .forward_intermediates import forward_intermediates
from .dual_hybrid_vit import HybridModel
from flash_attn import flash_attn_varlen_func


def _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    N, C = x.shape
    qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

    x = flash_attn_varlen_func(
        q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
    ).reshape(N, -1)

    x = self.proj(x)
    x = self.proj_drop(x)
    return x

def _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_seqlens)))
    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
    return x

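def _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    device = images[0].device
    x = []
    seqlens = []
    for image in images:
        # image: [1, c, H, W] -> x: [n_cls+h*w, D], h=H/p and w=W/p
        _image = self.patch_generator(image).squeeze(0)
        x.append(_image)
        seqlens.append(_image.shape[0])

    # Pack all images into one [sum(seqlens), D] sequence; cu_seqlens marks the
    # per-image boundaries so varlen attention never mixes tokens across images.
    x = torch.cat(x, dim=0)
    seqlens = torch.tensor(seqlens, device=device, dtype=torch.int32)

    cu_seqlens = torch.cat([
        torch.tensor([0], device=device, dtype=torch.int32),
        torch.cumsum(seqlens, dim=0, dtype=torch.int32),
    ])
    if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
        # timm's checkpoint_seq takes a module sequence and an `every` stride, so it
        # cannot forward cu_seqlens; checkpoint each packed block directly instead.
        from torch.utils.checkpoint import checkpoint
        for block in self.blocks:
            x = checkpoint(block, x, cu_seqlens, use_reentrant=False)
    else:
        for block in self.blocks:
            x = block(x, cu_seqlens)
    x = self.norm(x)
    return x, cu_seqlens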

def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
    x = self.patch_generator(x)
    if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    x = self.norm(x)
    return x


def _take_indices(
        num_blocks: int,
        n: Optional[Union[int, List[int], Tuple[int]]],
) -> Tuple[Set[int], int]:
    if isinstance(n, int):
        assert n >= 0
        take_indices = {x for x in range(num_blocks - n, num_blocks)}
    else:
        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
    return take_indices, max(take_indices)


def _forward_intermediates_cpe(
        self,
        x: torch.Tensor,
        norm: bool = False,
        **kwargs,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    return forward_intermediates(
        self,
        patch_extractor=self.patch_generator,
        num_summary_tokens=self.patch_generator.num_skip,
        num_cls_tokens=self.patch_generator.num_cls_tokens,
        norm=self.norm if norm else lambda y: y,
        x=x,
        **kwargs,
    )


def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor:
    y = _forward_cpe(self.inner, x)

    return y[:, 0], y[:, self.num_summary_tokens:]


def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):
    return _forward_intermediates_cpe(self.inner, *args, **kwargs)


def _enable_cpe_for_timm_vit(model: VisionTransformer,
                             max_img_size: Union[int, Tuple[int, int]] = 1024,
                             num_cls_tokens: int = 1,
                             pos_dropout: float = 0.1,
                             register_multiple: Optional[int] = None,
                             num_registers: Optional[int] = None,
                             support_packing: bool = False,
):
    if not isinstance(model, VisionTransformer):
        raise ValueError("CPE is only supported for VisionTransformer models!")

    patch_size = model.patch_embed.patch_size[0]
    embed_dim = model.embed_dim
    input_dims = model.patch_embed.img_size
    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
    cls_token = model.cls_token is not None

    max_img_size = int(round(max_img_size / patch_size) * patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
        num_registers=num_registers,
    )

    model.patch_generator = patch_generator
    model.patch_embed = None
    model.cls_token = None
    model.pos_embed = None
    model.pos_drop = None
    model.patch_size = patch_size
    model.num_cls_tokens = num_cls_tokens
    model.num_registers = patch_generator.num_registers

    model.forward_features = MethodType(_forward_cpe, model)
    model.forward_intermediates = MethodType(_forward_intermediates_cpe, model)
    if support_packing:
        model.forward_features = MethodType(_forward_cpe_pack, model)
        for block in model.blocks:
            block.forward = MethodType(_block_forward_pack, block)
            block.attn.forward = MethodType(_attn_forward_pack, block.attn)


def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,
                                max_img_size: Union[int, Tuple[int, int]] = 1024,
                                num_cls_tokens: int = 1,
                                pos_dropout: float = 0.1,
                                register_multiple: Optional[int] = None,
                                num_registers: Optional[int] = None,
):
    patch_size = model.patch_size
    embed_dim = model.embed_dim
    input_dims = model.inner.patch_embed.patches_resolution
    normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity)
    cls_token = True

    max_img_size = int(round(max_img_size / patch_size) * patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
        num_registers=num_registers,
        patch_bias=True,
    )

    inner = model.inner
    inner.patch_generator = patch_generator
    inner.patch_embed = None
    inner.cls_token = None
    inner.pos_embed = None
    inner.register_tokens = None
    inner.patch_size = patch_size

    model.forward_features = MethodType(_forward_cpe_dinov2, model)
    model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model)


def enable_cpe(model: nn.Module,
               *args,
               **kwargs,
):
    if isinstance(model, VisionTransformer):
        _enable_cpe_for_timm_vit(model, *args, **kwargs)
    elif isinstance(model, DinoWrapper):
        _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
    elif isinstance(model, HybridModel):
        _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
    else:
        raise ValueError(f'CPE not supported for this model type: {type(model)}')
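In packing mode, `_forward_cpe_pack` concatenates every image's tokens into a single `[total_tokens, D]` tensor. It records the boundaries in `cu_seqlens`, which is a leading 0 followed by the running sum of the per-image lengths. `flash_attn_varlen_func` then restricts attention to tokens within the same segment, which is equivalent to a block-diagonal attention mask. The stdlib-only sketch below shows that bookkeeping. `packed_layout`, `segment_of`, `may_attend`, `num_prefix_tokens` and the example image sizes are all hypothetical. They stand in for what `ViTPatchGenerator` produces, assumed here to be one token per `p x p` patch plus a fixed number of CLS/register tokens.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def packed_layout(image_sizes: List[Tuple[int, int]], patch_size: int,
                  num_prefix_tokens: int = 1) -> Tuple[List[int], int]:
    """Return (cu_seqlens, max_seqlen) for a packed batch of (H, W) images."""
    seqlens = [num_prefix_tokens + (h // patch_size) * (w // patch_size)
               for h, w in image_sizes]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return cu_seqlens, max(seqlens)


def segment_of(cu_seqlens: List[int], row: int) -> int:
    """Index of the image that owns packed row `row`."""
    if not 0 <= row < cu_seqlens[-1]:
        raise IndexError(row)
    return bisect_right(cu_seqlens, row) - 1


def may_attend(cu_seqlens: List[int], i: int, j: int) -> bool:
    """Varlen attention lets row i see row j only inside the same image."""
    return segment_of(cu_seqlens, i) == segment_of(cu_seqlens, j)


cu, max_len = packed_layout([(256, 256), (512, 256)], patch_size=16)
print(cu, max_len)  # [0, 257, 770] 513
```

Packing avoids padding the 256x256 image up to the 513-token length of the larger one. The cost is that every attention call needs the boundary list and the maximum segment length.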


================================================
FILE: nit/models/nvidia_radio/radio/enable_damp.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from logging import getLogger
import math
import os
from typing import Dict, List, Optional, Union, Tuple
from types import MethodType

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize


# For now, don't do anything
class DAMP(nn.Identity):
    def __init__(self, std: float):
        super().__init__()
        self.std = std


def enable_damp(model: nn.Module, std: float):
    if isinstance(model, (list, tuple)):
        for m in model:
            enable_damp(m, std)
        return

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            parametrize.register_parametrization(module, 'weight', DAMP(std))


def configure_damp_from_args(model: nn.Module, args):
    damp = getattr(args, 'damp', None)
    if damp:
        enable_damp(model, damp)


================================================
FILE: nit/models/nvidia_radio/radio/enable_spectral_reparam.py
================================================
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from logging import getLogger
import math
import os
from typing import Dict, List, Optional, Union, Tuple
from types import MethodType

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import _SpectralNorm

from timm.models.vision_transformer import Attention, Mlp

_EPS = 1e-5


class _SNReweight(_SpectralNorm):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
        super().__init__(weight, *args, **kwargs)

        self.alpha = alpha
        self.version = version
        self.register_buffer('_sn_version', torch.tensor(version))

        if init_norm_to_current:
            # This will set the numerator to match the denominator, which should preserve the original values
            init_scale = self._get_sigma(weight, n_power_iterations=20).item()
        else:
            init_scale = 1.0

        if version == 1:
            init_value = init_scale
        elif version == 2:
            t = init_scale - alpha
            if t < _EPS:
                getLogger("spectral_reparam").warning(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
                t = _EPS

            init_value = math.log(math.exp(t) - 1)
        else:
            raise ValueError(f'Unsupported version: {version}')

        # Make 2D so that weight decay gets applied
        self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))

    # Re-implementing this because we need to make division by sigma safe
    def _get_sigma(self, weight: torch.Tensor, n_power_iterations: Optional[int] = None) -> torch.Tensor:
        if not n_power_iterations:
            n_power_iterations = self.n_power_iterations
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            sigma = weight.norm()
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, n_power_iterations)
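            # Clone u and v: the power method updates them in place, and autograd needs
            # the values used in this forward pass (same reasoning as torch's _SpectralNorm)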
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))

        return sigma + self.eps

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        dtype = weight.dtype
        sigma = self._get_sigma(weight, *args, **kwargs)

        if self.version == 1:
            scale = self.scale
        elif self.version == 2:
            scale = F.softplus(self.scale) + self.alpha
        else:
            raise ValueError(f'Unsupported version: {self.version}')

        scale = scale.float() / sigma.float()

        y = weight * scale

        if dtype in (torch.float16, torch.bfloat16):
            y = y.to(dtype)
        return y

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        version_key = f'{prefix}_sn_version'
        if version_key not in state_dict:
            self.version = 1
            state_dict[version_key] = torch.tensor(1)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class _ChunkedSNReweight(nn.Module):
    def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):
        super().__init__()

        self.num_chunks = num_chunks
        parts = weight.split(weight.shape[0] // num_chunks, dim=0)

        self.parts = nn.ModuleList([
            _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs)
            for p in parts
        ])

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        parts = weight.split(weight.shape[0] // self.num_chunks, dim=0)

        parts = [
            fn(p)
            for fn, p in zip(self.parts, parts)
        ]

        return torch.cat(parts, dim=0)


class _AttnSNReweight(_ChunkedSNReweight):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
        super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)

        if not renorm_values:
            self.parts[2] = nn.Identity()


def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
                            n_power_iterations: int = 1,
                            eps: float = 1e-6,
                            init_norm_to_current: bool = False,
                            renorm_values: bool = True,
                            renorm_mlp: bool = True,
                            state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    if isinstance(model, (list, tuple)):
        for i, sub in enumerate(model):
            sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance
            enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,
                                    init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,
                                    renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)
        return

    print('Enabling spectral reparametrization')
    args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)
    visited_prefixes = set()

    def is_guidance_parametrized(name: str):
        if state_dict_guidance is None:
            return True

        p_name = f'{name}.parametrizations'
        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
        return is_prm

    def parametrize_linear(linear: nn.Linear):
        parametrize.register_parametrization(
            linear,
            'weight',
            _SNReweight(linear.weight, **args)
        )

    for name, mod in model.named_modules():
        pref = '.'.join(name.split('.')[:-1])
        if pref in visited_prefixes:
            continue

        if isinstance(mod, Attention) or name.endswith('.attn'):
            if is_guidance_parametrized(f'{name}.qkv'):
                parametrize.register_parametrization(
                    mod.qkv,
                    'weight',
                    _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),
                )
            if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):
                parametrize_linear(mod.proj)
            visited_prefixes.add(name)
        elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):
            if is_guidance_parametrized(f'{name}.w12'):
                parametrize.register_parametrization(
                    mod.w12,
                    'weight',
                    _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),
                )
            if is_guidance_parametrized(f'{name}.w3'):
                parametrize_linear(mod.w3)
            visited_prefixes.add(name)
        elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):
            parametrize_linear(mod)


def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
    spectral_reparam = getattr(args, 'spectral_reparam', False)
    if isinstance(spectral_reparam, bool) and spectral_reparam:
        enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)
    elif isinstance(spectral_reparam, dict):
        enable_spectral_reparam(
            model,
            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
            eps=spectral_reparam.get('eps', 1e-12),
            init_norm_to_current=True,
            state_dict_guidance=state_dict_guidance,
        )


def disable_spectral_reparam(model: nn.Module):
    print('Disabling spectral reparametrization')
    for name, mod in model.named_modules():
        if parametrize.is_parametrized(mod):
            parametrize.remove_parametrizations(mod, 'weight')



if __name__ == '__main__':
    import argparse
    from . import radio_model as create_model

    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')

    args = parser.parse_args()

    if not args.output:
        chk_dir, chk_name = os.path.split(args.checkpoint)
        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
        print(f'Set output to "{args.output}"')

    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)

    model = create_model.create_model_from_args(chk['args'])

    key = 'base_model.'
    mod_state = dict()
    extra_state = dict()
    for k, v in chk['state_dict'].items():
        if k.startswith(key):
            mod_state[k[len(key):]] = v
        else:
            extra_state[k] = v

    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
        print(chk_load_info)

    if chk['args'].spectral_reparam:
        disable_spectral_reparam(model)

    if hasattr(chk['args'], 'dtype'):
        model.to(dtype=chk['args'].dtype)

    mod_state = model.state_dict()
    final_state = dict()
    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
    final_state.update(extra_state)

    chk['state_dict'] = final_state
    chk['args'].spectral_reparam = False

    if args.release:
        chk = {
            'arch': chk['arch'],
            'epoch': chk['epoch'],
            'state_dict': chk['state_dict'],
            'args': chk['args'],
        }

    torch.save(chk, args.output)
    pass
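For version 2, `_SNReweight` rescales each weight as `W * (softplus(scale) + alpha) / sigma(W)`. Here `sigma` is the largest singular value, estimated by power iteration as `u . (W v)` plus `eps`. With `init_norm_to_current=True`, `scale` is initialised to the inverse softplus of `sigma - alpha`, floored at `_EPS`, so the reparametrised weight starts out equal to the original. The stdlib-only sketch below reproduces that arithmetic; the helper names are hypothetical. It differs from `_SNReweight` in one respect. `_SNReweight` keeps `u` and `v` as persistent buffers and runs `n_power_iterations` steps per training forward (20 at initialisation). This sketch instead restarts from a fixed vector and runs 50 iterations on every call.

```python
import math
from typing import List

Matrix = List[List[float]]


def _matvec(w: Matrix, v: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in w]


def _matvec_t(w: Matrix, u: List[float]) -> List[float]:
    # Computes w^T u without building the transpose.
    return [sum(w[i][j] * u[i] for i in range(len(w))) for j in range(len(w[0]))]


def _normalize(x: List[float], eps: float = 1e-12) -> List[float]:
    norm = math.sqrt(sum(a * a for a in x))
    return [a / max(norm, eps) for a in x]


def softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def estimate_sigma(w: Matrix, n_iters: int = 50, eps: float = 1e-6) -> float:
    """Largest singular value of w via power iteration: sigma ~= u . (W v) + eps."""
    u = _normalize([1.0] * len(w))
    v = [0.0] * len(w[0])
    for _ in range(n_iters):
        v = _normalize(_matvec_t(w, u))
        u = _normalize(_matvec(w, v))
    return sum(a * b for a, b in zip(u, _matvec(w, v))) + eps


def init_param(init_scale: float, alpha: float = 0.05, min_t: float = 1e-5) -> float:
    """Version-2 init: inverse softplus of (init_scale - alpha), floored at min_t."""
    t = max(init_scale - alpha, min_t)
    return math.log(math.expm1(t))


def reweighted(w: Matrix, param: float, alpha: float = 0.05) -> Matrix:
    """Effective weight: w * (softplus(param) + alpha) / sigma(w)."""
    scale = (softplus(param) + alpha) / estimate_sigma(w)
    return [[x * scale for x in row] for row in w]


w = [[1.0, 2.0], [3.0, 4.0]]
w_eff = reweighted(w, init_param(estimate_sigma(w)))
print(w_eff)  # equal to w up to floating-point rounding
```

The `softplus(...) + alpha` form keeps the learned norm strictly above `alpha`. This is also why a norm smaller than `alpha + _EPS` cannot be represented at initialisation, which triggers the warning in `_SNReweight.__init__`.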


================================================
FILE: nit/models/nvidia_radio/radio/eradio_model.py
================================================
#!/usr/bin/env python3

# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# E-RADIO model from
# Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).

# based on FasterViT, Swin Transformer, YOLOv8

# FasterViT:
# Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).

import timm
import torch
import torch.nn as nn
from timm.models.registry import register_model

from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
import numpy as np
import torch.nn.functional as F
import math
import warnings

#######################
## Codebase from YOLOv8
## BEGINNING
#######################

class C2f(nn.Module):
    """Faster implementation of a CSP bottleneck with 2 convolutions (from the YOLOv8 codebase)."""
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        if drop_path is None:
            drop_path = [0.0] * n

        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):  # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply cv1 then cv2, adding a (drop-path) residual when shortcut is enabled and c1 == c2."""
        return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))


class Conv(nn.Module):
    """Modified to support layer fusion"""
    default_act = nn.SiLU()  # default activation

    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
        super().__init__()

        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()


    def forward(self,x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

    @torch.no_grad()
    def switch_to_deploy(self):
        # return 1
        if not isinstance(self.bn, nn.Identity):
            c, bn = self.conv, self.bn
            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
            w = c.weight * w[:, None, None, None]
            b = bn.bias - bn.running_mean * bn.weight / \
                (bn.running_var + bn.eps)**0.5

            self.conv.weight.data.copy_(w)
            self.conv.bias = nn.Parameter(b)

            self.bn = nn.Identity()
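
# Illustrative sketch (not part of the original module): a NumPy check of the
# BatchNorm folding performed in switch_to_deploy(). It uses a 1x1 convolution
# (a per-pixel matmul) to stay short. With std = sqrt(running_var + eps),
# BN(conv(x)) = gamma * (W x - mean) / std + beta, which equals a single conv with
# W' = W * (gamma / std) per output channel and b' = beta - mean * gamma / std.
# This holds because the conv has no bias of its own (bias=False above).
# The helper names are hypothetical.
import numpy as np


def _sketch_fold_bn(weight, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode BN statistics into a bias-free conv's weight and bias."""
    scale = gamma / np.sqrt(var + eps)
    folded_w = weight * scale.reshape(-1, *([1] * (weight.ndim - 1)))
    folded_b = beta - mean * scale
    return folded_w, folded_b


def _sketch_conv1x1(x, weight, bias=None):
    """1x1 conv on (B, C_in, H, W) with weight (C_out, C_in, 1, 1)."""
    y = np.einsum("oi,bihw->bohw", weight[:, :, 0, 0], x)
    return y if bias is None else y + bias[None, :, None, None]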

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
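
# Illustrative worked example (not part of the original module): autopad first
# converts a dilated kernel to its effective size d * (k - 1) + 1, then pads by
# half of that. At stride 1 with an odd kernel, this keeps H and W unchanged.
# For example, k=3, d=2 gives an effective kernel of 5 and padding 2. The helper
# below is hypothetical; it restates that rule and computes the output size.
def _sketch_same_conv_out(size, k, d=1, s=1):
    k_eff = d * (k - 1) + 1
    p = k_eff // 2  # what autopad(k, None, d) returns for an int k
    return (size + 2 * p - k_eff) // s + 1


# Usage: _sketch_same_conv_out(32, 3, d=2) == 32; _sketch_same_conv_out(32, 3, s=2) == 16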


#######################
## Codebase from YOLOv8
## END
#######################

def pixel_unshuffle(data, factor=2):
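    # Space-to-depth (B, C, H, W) -> (B, C*factor**2, H//factor, W//factor), written out manually because
    # the torch op has issues with ONNX and TRT. H and W are each split into `factor` contiguous blocks and
    # every tile becomes a group of channels, so the channel order differs from nn.PixelUnshuffle (interleaved).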
    B, C, H, W = data.shape
    return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
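
# Illustrative sketch (not part of the original module): a NumPy restatement of
# pixel_unshuffle above, compared with the interleaved ordering used by
# torch.nn.PixelUnshuffle. Both produce (B, C*f*f, H//f, W//f). The manual
# version moves each contiguous tile into channels, while the interleaved version
# gathers every f-th pixel, so the channel contents differ. Swapping one for the
# other would therefore change what pretrained weights see. Helper names are hypothetical.
import numpy as np


def _sketch_tile_unshuffle(data, factor=2):
    """Same view/permute/reshape as pixel_unshuffle above, in NumPy."""
    B, C, H, W = data.shape
    x = data.reshape(B, C, factor, H // factor, factor, W // factor)
    return x.transpose(0, 1, 2, 4, 3, 5).reshape(B, -1, H // factor, W // factor)


def _sketch_interleaved_unshuffle(data, factor=2):
    """Interleaved space-to-depth, the ordering of torch.nn.PixelUnshuffle."""
    B, C, H, W = data.shape
    x = data.reshape(B, C, H // factor, factor, W // factor, factor)
    return x.transpose(0, 1, 3, 5, 2, 4).reshape(B, -1, H // factor, W // factor)


# Usage on a 4x4 ramp 0..15: channel 0 is [[0, 1], [4, 5]] (top-left tile) for the
# manual version and [[0, 2], [8, 10]] (every second pixel) for the interleaved one.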

class SwiGLU(nn.Module):
    # Minimal SwiGLU gate: splits the last dim into (value, gate) and returns SiLU(gate) * value.
    # A more advanced variant has not improved results so far.
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x
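
# Illustrative sketch (not part of the original module): SwiGLU above splits the
# last dimension in half and gates one half with SiLU of the other. This means a
# preceding linear layer must emit 2 * hidden features to produce `hidden` outputs.
# Note the order: the first half is the value and the second half is the gate.
# NumPy is used here in place of torch; the function name is hypothetical.
import numpy as np


def _sketch_swiglu(x):
    value, gate = np.split(x, 2, axis=-1)
    silu = gate / (1.0 + np.exp(-gate))  # SiLU(g) = g * sigmoid(g)
    return silu * value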


def window_partition(x, window_size):
    """
    Partition a feature map into non-overlapping windows for windowed attention.
    Args:
        x: (B, C, H, W)
        window_size: window size; 0, or a window equal to the full H x W map, means one global window
    Returns:
        windows: local window features (num_windows*B, window_size*window_size, C)
        (Hp, Wp): the size of the padded feature map
    """
    B, C, H, W = x.shape

    if window_size == 0 or (window_size==H and window_size==W):
        windows = x.flatten(2).transpose(1, 2)
        Hp, Wp = H, W
    else:
        pad_h = (window_size - H % window_size) % window_size
        pad_w = (window_size - W % window_size) % window_size
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
        Hp, Wp = H + pad_h, W + pad_w

        x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
        windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)

    return windows, (Hp, Wp)
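
# Illustrative worked example (not part of the original module): the shape
# arithmetic of window_partition. Each side is padded up to the next multiple of
# window_size (reflect padding in the code above, added at the bottom and right).
# The map is then cut into (Hp // ws) * (Wp // ws) windows per image. For a
# 14 x 14 map with 8 x 8 windows: 2 padded rows and columns -> 16 x 16 -> 4 windows
# of 64 tokens each. The helper name is hypothetical.
def _sketch_window_partition_shape(B, C, H, W, window_size):
    """Return (windows.shape, (Hp, Wp)) that window_partition would produce."""
    if window_size == 0 or (window_size == H and window_size == W):
        return (B, H * W, C), (H, W)  # a single global "window" per image
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    Hp, Wp = H + pad_h, W + pad_w
    num_windows = (Hp // window_size) * (Wp // window_size)
    return (B * num_windows, window_size * window_size, C), (Hp, Wp)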

class Conv2d_BN(nn.Module):
    '''
    Conv2d + BN layer with folding capability to speed up inference.
    Could be merged into Conv() given additional arguments.
    '''
    def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = torch.nn.BatchNorm2d(b)
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

    @torch.no_grad()
    def switch_to_deploy(self):
        if not isinstance(self.bn, nn.Identity):
            c, bn = self.conv, self.bn
            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
            w = c.weight * w[:, None, None, None]
            b = bn.bias - bn.running_mean * bn.weight / \
                (bn.running_var + bn.eps)**0.5
            self.conv.weight.data.copy_(w)
            self.conv.bias = nn.Parameter(b)
            self.bn = nn.Identity()



def window_reverse(windows, window_size, H, W, pad_hw):
    """
    Merge windows back into the full feature map (inverse of window_partition).
    Args:
        windows: local window features (num_windows*B, window_size*window_size, C)
        window_size: window size
        H: height of the unpadded feature map
        W: width of the unpadded feature map
        pad_hw: (Hp, Wp), the padded size returned by window_partition
    Returns:
        x: (B, C, H, W)
    """
    Hp, Wp = pad_hw
    if window_size == 0 or (window_size == H and window_size == W):
        # A single global window per image: windows is (B, H*W, C).
        # (Computing B by dividing by window_size would fail for window_size == 0.)
        B = windows.shape[0]
        x = windows.transpose(1, 2).view(B, -1, H, W)
    else:
        B = windows.shape[0] // ((Hp // window_size) * (Wp // window_size))
        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], Hp, Wp)

        if Hp > H or Wp > W:
            x = x[:, :, :H, :W].contiguous()  # crop away the padding

    return x
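
# Illustrative sketch (not part of the original module): a NumPy round trip of
# window_partition followed by window_reverse. It shows that the two permutations
# are inverses and that cropping to [:H, :W] removes the padding. Zero padding is
# used here instead of reflect padding; the padded values are cropped away either
# way. The helper names are hypothetical.
import numpy as np


def _sketch_partition(x, ws):
    B, C, H, W = x.shape
    pad_h, pad_w = (ws - H % ws) % ws, (ws - W % ws) % ws
    x = np.pad(x, ((0, 0), (0, 0), (0, pad_h), (0, pad_w)))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(B, C, Hp // ws, ws, Wp // ws, ws)
    # (B, nH, nW, ws, ws, C) -> (B*nH*nW, ws*ws, C)
    return x.transpose(0, 2, 4, 3, 5, 1).reshape(-1, ws * ws, C), (Hp, Wp)


def _sketch_reverse(windows, ws, H, W, pad_hw):
    Hp, Wp = pad_hw
    B = windows.shape[0] // ((Hp // ws) * (Wp // ws))
    x = windows.reshape(B, Hp // ws, Wp // ws, ws, ws, -1)
    # (B, nH, nW, ws, ws, C) -> (B, C, nH, ws, nW, ws) -> (B, C, Hp, Wp)
    x = x.transpose(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], Hp, Wp)
    return x[:, :, :H, :W]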



class PosEmbMLPSwinv2D(nn.Module):
    """
    2D continuous relative position bias from Swin Transformer v2.
    The computed bias is stored in the model so that, in deploy mode, it is not recomputed on every call.
    """
    def __init__(
        self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,
    ):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, cpb_mlp_hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(cpb_mlp_hidden, num_heads, bias=False),
        )

        self.grid_exists = False
        self.seq_length = seq_length
        self.deploy = False
        self.no_log = no_log
        self.pretrained_window_size = pretrained_window_size
        self.relative_bias_window_size = window_size

        relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,
                                                                                                     pretrained_window_size, seq_length,
                                                                                                     no_log)

        self.register_buffer("relative_coords_table", relative_coords_table)
        self.register_buffer("relative_position_index", relative_position_index)
        self.register_buffer("relative_bias", relative_bias)  # for EMA

    def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
        # Kept in a separate function so the window size can be changed after the model weights are loaded.
        relative_coords_h = torch.arange(
            -(window_size[0] - 1), window_size[0], dtype=torch.float32
        )
        relative_coords_w = torch.arange(
            -(window_size[1] - 1), window_size[1], dtype=torch.float32
        )
        relative_coords_table = (
            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
            .permute(1, 2, 0)
            .contiguous()
            .unsqueeze(0)
        )  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
        else:
            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        if not no_log:
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = (
                torch.sign(relative_coords_table)
                * torch.log2(torch.abs(relative_coords_table) + 1.0)
                / np.log2(8)
            )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

        relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)

        self.relative_bias_window_size = window_size

        return relative_coords_table, relative_position_index, relative_bias


    def switch_to_deploy(self):
        self.deploy = True
        self.grid_exists = True

    def forward(self, input_tensor):
        # For efficiency, this forward should reduce to a single operation (a sum):
        # if the resolution stays the same, the MLP layers do not need to be recomputed.

        if not self.deploy or self.training:
            self.grid_exists = False

        # Rebuild the tables if self.window_size no longer matches the size they were built for.
        if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):
            relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,
                                                                                                        self.pretrained_window_size, self.seq_length,
                                                                                                        self.no_log)

            self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)
            self.relative_position_index = relative_position_index.to(self.relative_position_index.device)
            self.relative_bias = relative_bias.to(self.relative_bias.device)
            # The cached bias was just reset to zeros, so force it to be recomputed below.
            self.grid_exists = False

        if self.deploy and self.grid_exists:
            input_tensor = input_tensor + self.relative_bias
            return input_tensor

        self.grid_exists = True

        relative_position_bias_table = self.cpb_mlp(
            self.relative_coords_table
        ).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )  # Wh*Ww,Wh*Ww,nH

        relative_position_bias = relative_position_bias.permute(
            2, 0, 1
        ).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)

        self.relative_bias = relative_position_bias.unsqueeze(0)

        input_tensor = input_tensor + self.relative_bias
        return input_tensor
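
# Illustrative sketch (not part of the original module): the Swin v2
# continuous-position-bias bookkeeping used by PosEmbMLPSwinv2D, in NumPy.
# For a (Wh, Ww) window, every (query, key) pair maps to one of
# (2*Wh - 1) * (2*Ww - 1) relative offsets. The MLP output table is gathered
# with that index and squashed with 16 * sigmoid, so the bias lies in (0, 16).
# `table` is a hypothetical stand-in for the output of self.cpb_mlp, and the
# helper names are hypothetical too.
import numpy as np


def _sketch_relative_position_index(wh, ww):
    coords = np.stack(np.meshgrid(np.arange(wh), np.arange(ww), indexing="ij"))  # 2, Wh, Ww
    flat = coords.reshape(2, -1)                                                 # 2, Wh*Ww
    rel = (flat[:, :, None] - flat[:, None, :]).transpose(1, 2, 0)               # Wh*Ww, Wh*Ww, 2
    rel[:, :, 0] += wh - 1      # shift offsets so they start at 0
    rel[:, :, 1] += ww - 1
    rel[:, :, 0] *= 2 * ww - 1  # row-major index into the (2Wh-1) x (2Ww-1) table
    return rel.sum(-1)


def _sketch_log_coords(coords, pretrained_size):
    """Log-spaced relative coordinates along one axis (the no_log=False branch)."""
    c = coords / (pretrained_size - 1) * 8
    return np.sign(c) * np.log2(np.abs(c) + 1.0) / np.log2(8)


def _sketch_gather_bias(table, index):
    """table: ((2Wh-1)*(2Ww-1), num_heads) -> bias of shape (num_heads, N, N)."""
    n = index.shape[0]
    bias = table[index.reshape(-1)].reshape(n, n, -1).transpose(2, 0, 1)
    return 16.0 / (1.0 + np.exp(-bias))  # 16 * sigmoid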


class GRAAttentionBlock(nn.Module):
    def __init__(self, window_size, dim_in, dim_out,
                 num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
                 norm_layer=nn.LayerNorm, layer_scale=None,
                 use_swiglu=True,
                 subsample_ratio=1, dim_ratio=1, conv_base=False,
                 do_windowing=True, multi_query=False, use_shift=0,
                 cpb_mlp_hidden=512, conv_groups_ratio=0):
        '''
        Global Resolution Attention Block, see README for details
        Attention with subsampling to get a bigger receptive field for attention
        conv_base - u
SYMBOL INDEX (746 symbols across 60 files)

FILE: nit/data/pack/__init__.py
  function get_strategy (line 11) | def get_strategy(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_l...
  function pack_dataset (line 30) | def pack_dataset(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_l...

FILE: nit/data/pack/ennlshp.py
  function get_packing_matrix (line 10) | def get_packing_matrix(strategy_set, max_sequence_length):
  function get_packing_strategies (line 20) | def get_packing_strategies(start_length, minimum_increment, target_lengt...
  function ENNLSHP (line 41) | def ENNLSHP(histogram, max_sequence_length, max_sequences_per_pack):

FILE: nit/data/pack/lpfhp.py
  function add_pack (line 9) | def add_pack(pack, count, tmp, final, limit, offset, max_sequence_length...
  function LPFHP (line 21) | def LPFHP(histogram, max_sequence_length, max_sequences_per_pack, distri...

FILE: nit/data/pack/nnlshp.py
  function get_packing_matrix (line 10) | def get_packing_matrix(strategy_set, max_sequence_length):
  function get_packing_strategies (line 20) | def get_packing_strategies(start_length, minimum_increment, target_lengt...
  function NNLSHP (line 41) | def NNLSHP(histogram, max_sequence_length, max_sequences_per_pack):

FILE: nit/data/pack/spfhp.py
  function add_pack (line 9) | def add_pack(pack, count, tmp, final, limit, offset):
  function SPFHP (line 17) | def SPFHP(histogram, max_sequence_length, max_sequences_per_pack):

FILE: nit/data/packed_c2i_data.py
  function resize_arr (line 29) | def resize_arr(pil_image, height, width):
  function center_crop_arr (line 34) | def center_crop_arr(pil_image, image_size):
  function packed_collate_fn (line 54) | def packed_collate_fn(batch):
  class ImprovedPackedImageNetLatentDataset (line 75) | class ImprovedPackedImageNetLatentDataset(Dataset):
    method __init__ (line 76) | def __init__(self, packed_json, jsonl_dir, data_types, latent_dirs, im...
    method __len__ (line 91) | def __len__(self):
    method __getitem__ (line 94) | def __getitem__(self, index):
  class C2ILoader (line 127) | class C2ILoader():
    method __init__ (line 128) | def __init__(self, data_config):
    method train_len (line 148) | def train_len(self):
    method train_dataloader (line 151) | def train_dataloader(self, rank, world_size, global_batch_size, max_st...
    method test_dataloader (line 175) | def test_dataloader(self):
    method val_dataloader (line 178) | def val_dataloader(self):

FILE: nit/data/sampler_util.py
  function get_train_sampler (line 5) | def get_train_sampler(dataset, rank, world_size, global_batch_size, max_...
  function get_packed_batch_sampler (line 30) | def get_packed_batch_sampler(

FILE: nit/models/c2i/nit_model.py
  function modulate (line 16) | def modulate(x, shift, scale):
  function build_mlp (line 19) | def build_mlp(hidden_size, projector_dim, z_dim):
  class TimestepEmbedder (line 30) | class TimestepEmbedder(nn.Module):
    method __init__ (line 34) | def __init__(self, hidden_size, frequency_embedding_size=256):
    method positional_embedding (line 44) | def positional_embedding(t, dim, max_period=10000):
    method forward (line 64) | def forward(self, t):
  class LabelEmbedder (line 71) | class LabelEmbedder(nn.Module):
    method __init__ (line 75) | def __init__(self, num_classes, hidden_size, dropout_prob):
    method forward (line 82) | def forward(self, labels):
  class Attention (line 91) | class Attention(nn.Module):
    method __init__ (line 92) | def __init__(
    method forward (line 115) | def forward(self, x: torch.Tensor, cu_seqlens, freqs_cos, freqs_sin) -...
  class NiTBlock (line 142) | class NiTBlock(nn.Module):
    method __init__ (line 146) | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwar...
    method forward (line 172) | def forward(self, x, c, cu_seqlens, freqs_cos, freqs_sin):
  class FinalLayer (line 182) | class FinalLayer(nn.Module):
    method __init__ (line 186) | def __init__(self, hidden_size, patch_size, out_channels):
    method forward (line 195) | def forward(self, x, c):
  class NiT (line 203) | class NiT(nn.Module):
    method __init__ (line 207) | def __init__(
    method initialize_weights (line 258) | def initialize_weights(self):
    method unpatchify (line 291) | def unpatchify(self, x, patch_size=None):
    method get_rope (line 306) | def get_rope(self, hw_list):
    method forward (line 318) | def forward(self, x, t, y, hw_list, return_zs=False, return_logvar=Fal...
    method ckpt_wrapper (line 364) | def ckpt_wrapper(self, module):
    method dtype (line 371) | def dtype(self) -> torch.dtype:

FILE: nit/models/nvidia_radio/hubconf.py
  function radio_model (line 30) | def radio_model(
  function get_prefix_state_dict (line 188) | def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):

FILE: nit/models/nvidia_radio/radio/adaptor_base.py
  class AdaptorInput (line 16) | class AdaptorInput(NamedTuple):
  class RadioOutput (line 24) | class RadioOutput(NamedTuple):
    method to (line 28) | def to(self, *args, **kwargs):
  class AdaptorBase (line 35) | class AdaptorBase(nn.Module):
    method forward (line 36) | def forward(self, input: AdaptorInput) -> RadioOutput:

FILE: nit/models/nvidia_radio/radio/adaptor_generic.py
  class GenericAdaptor (line 18) | class GenericAdaptor(AdaptorBase):
    method __init__ (line 19) | def __init__(self, main_config: Namespace, adaptor_config, state, mlp_...
    method forward (line 58) | def forward(self, input: AdaptorInput) -> RadioOutput:

FILE: nit/models/nvidia_radio/radio/adaptor_mlp.py
  class MLP (line 20) | class MLP(nn.Module):
    method __init__ (line 21) | def __init__(self, input_size: int, hidden_size: int, output_size: int,
    method forward (line 42) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class MLP2 (line 51) | class MLP2(nn.Module):
    method __init__ (line 52) | def __init__(self, input_size: int, hidden_size: int, output_size: int,
    method forward (line 91) | def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = No...
  function strip_prefix (line 117) | def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
  function get_mlp_info_from_state (line 126) | def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor]...
  function create_mlp_from_config (line 153) | def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int...
  function create_mlp_from_state (line 159) | def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], ...

FILE: nit/models/nvidia_radio/radio/adaptor_registry.py
  class AdaptorRegistry (line 19) | class AdaptorRegistry:
    method __init__ (line 20) | def __init__(self):
    method register_adaptor (line 23) | def register_adaptor(self, name):
    method create_adaptor (line 31) | def create_adaptor(self, name, main_config: Namespace, adaptor_config:...

FILE: nit/models/nvidia_radio/radio/block.py
  class C2f (line 15) | class C2f(nn.Module):
    method __init__ (line 18) | def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=...
    method forward (line 28) | def forward(self, x):
    method forward_split (line 34) | def forward_split(self, x):
  class Bottleneck (line 41) | class Bottleneck(nn.Module):
    method __init__ (line 44) | def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_p...
    method forward (line 52) | def forward(self, x):

FILE: nit/models/nvidia_radio/radio/cls_token.py
  class ClsToken (line 14) | class ClsToken(nn.Module):
    method __init__ (line 15) | def __init__(self, ndim: int,
    method disable (line 40) | def disable(self):
    method forward (line 44) | def forward(self, x: torch.Tensor):
    method no_weight_decay (line 56) | def no_weight_decay(self):

FILE: nit/models/nvidia_radio/radio/common.py
  class RadioResource (line 16) | class RadioResource:

FILE: nit/models/nvidia_radio/radio/conv.py
  function autopad (line 16) | def autopad(k, p=None, d=1):  # kernel, padding, dilation
  class Conv (line 25) | class Conv(nn.Module):
    method __init__ (line 28) | def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, d...
    method forward (line 39) | def forward(self,x):
    method switch_to_deploy (line 46) | def switch_to_deploy(self):

FILE: nit/models/nvidia_radio/radio/dinov2_arch.py
  function make_2tuple (line 44) | def make_2tuple(x):
  class PatchEmbed (line 53) | class PatchEmbed(nn.Module):
    method __init__ (line 65) | def __init__(
    method forward (line 96) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method flops (line 111) | def flops(self) -> float:
  class Attention (line 119) | class Attention(nn.Module):
    method __init__ (line 120) | def __init__(
    method forward (line 139) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class MemEffAttention (line 165) | class MemEffAttention(Attention):
    method forward (line 166) | def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
  class Mlp (line 185) | class Mlp(nn.Module):
    method __init__ (line 186) | def __init__(
    method forward (line 203) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class SwiGLUFFN (line 212) | class SwiGLUFFN(nn.Module):
    method __init__ (line 213) | def __init__(
    method forward (line 228) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class SwiGLUFFNFused (line 239) | class SwiGLUFFNFused(SwiGLU):
    method __init__ (line 240) | def __init__(
  function drop_path (line 260) | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
  class DropPath (line 272) | class DropPath(nn.Module):
    method __init__ (line 275) | def __init__(self, drop_prob=None):
    method forward (line 279) | def forward(self, x):
  class LayerScale (line 283) | class LayerScale(nn.Module):
    method __init__ (line 284) | def __init__(
    method forward (line 294) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method _load_from_state_dict (line 297) | def _load_from_state_dict(self, state_dict, prefix, local_metadata, st...
  class Block (line 322) | class Block(nn.Module):
    method __init__ (line 323) | def __init__(
    method forward (line 368) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class NestedTensorBlock (line 396) | class NestedTensorBlock(Block):
    method forward_nested (line 397) | def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Ten...
    method forward (line 437) | def forward(self, x_or_x_list):
  function drop_add_residual_stochastic_depth (line 448) | def drop_add_residual_stochastic_depth(
  function get_branges_scales (line 472) | def get_branges_scales(x, sample_drop_ratio=0.0):
  function add_residual (line 480) | def add_residual(x, brange, residual, residual_scale_factor, scaling_vec...
  function get_attn_bias_and_cat (line 495) | def get_attn_bias_and_cat(x_list, branges=None):
  function drop_add_residual_stochastic_depth_list (line 519) | def drop_add_residual_stochastic_depth_list(
  function named_apply (line 542) | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=Tr...
  class BlockChunk (line 553) | class BlockChunk(nn.ModuleList):
    method forward (line 554) | def forward(self, x):
  class DinoVisionTransformer (line 560) | class DinoVisionTransformer(nn.Module):
    method __init__ (line 561) | def __init__(
    method interpolate_pos_encoding (line 682) | def interpolate_pos_encoding(self, x, w, h):
    method prepare_tokens_with_masks (line 716) | def prepare_tokens_with_masks(self, x, masks=None):
    method forward_features_list (line 737) | def forward_features_list(self, x_list, masks_list):
    method forward_features (line 757) | def forward_features(self, x, masks=None):
    method _get_intermediate_layers_not_chunked (line 775) | def _get_intermediate_layers_not_chunked(self, x, n=1):
    method _get_intermediate_layers_chunked (line 787) | def _get_intermediate_layers_chunked(self, x, n=1):
    method get_intermediate_layers (line 801) | def get_intermediate_layers(
    method forward (line 827) | def forward(self, *args, is_training=False, **kwargs):
  function vit_small (line 835) | def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_base (line 849) | def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_large (line 863) | def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
  function vit_giant2 (line 877) | def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
  class Weights (line 894) | class Weights(Enum):
  function _make_dinov2_model (line 898) | def _make_dinov2_model(
  function dinov2_vits14 (line 934) | def dinov2_vits14(**kwargs):
  function dinov2_vitb14 (line 941) | def dinov2_vitb14(**kwargs):
  function dinov2_vitl14 (line 948) | def dinov2_vitl14(**kwargs):
  function dinov2_vitg14 (line 955) | def dinov2_vitg14(**kwargs):
  function dinov2_vits14_reg (line 966) | def dinov2_vits14_reg(**kwargs):
  function dinov2_vitb14_reg (line 979) | def dinov2_vitb14_reg(**kwargs):
  function dinov2_vitl14_reg (line 992) | def dinov2_vitl14_reg(**kwargs):
  function dinov2_vitg14_reg (line 1005) | def dinov2_vitg14_reg(**kwargs):

FILE: nit/models/nvidia_radio/radio/dual_hybrid_vit.py
  class Fuser (line 17) | class Fuser(nn.Module):
    method __init__ (line 18) | def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
    method forward (line 30) | def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
  class AttnDownsample (line 65) | class AttnDownsample(nn.Module):
    method __init__ (line 66) | def __init__(self, dim: int, window_size: int, num_heads: int = 16):
    method forward (line 76) | def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> tor...
  class HybridModel (line 108) | class HybridModel(nn.Module):
    method __init__ (line 109) | def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, ...
    method patch_size (line 143) | def patch_size(self):
    method no_fsdp_wrap_types (line 147) | def no_fsdp_wrap_types(self):
    method forward (line 150) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method forward_features (line 153) | def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  function hybrid_base (line 190) | def hybrid_base(pretrained=False, concatenate: bool = False, weight_init...
  function hybrid_large (line 199) | def hybrid_large(pretrained=False, concatenate: bool = False, weight_ini...
  function hybrid_huge (line 208) | def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init...

FILE: nit/models/nvidia_radio/radio/enable_cpe_support.py
  function _attn_forward_pack (line 27) | def _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: tor...
  function _block_forward_pack (line 42) | def _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch....
  function _forward_cpe_pack (line 47) | def _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor...
  function _forward_cpe (line 73) | def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
  function _take_indices (line 83) | def _take_indices(
  function _forward_intermediates_cpe (line 95) | def _forward_intermediates_cpe(
  function _forward_cpe_dinov2 (line 112) | def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Ten...
  function _forward_intermediates_cpe_dinov2 (line 118) | def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):
  function _enable_cpe_for_timm_vit (line 122) | def _enable_cpe_for_timm_vit(model: VisionTransformer,
  function _enable_cpe_for_dv2_reg_vit (line 172) | def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,
  function enable_cpe (line 213) | def enable_cpe(model: nn.Module,

FILE: nit/models/nvidia_radio/radio/enable_damp.py
  class DAMP (line 22) | class DAMP(nn.Identity):
    method __init__ (line 23) | def __init__(self, std: float):
  function enable_damp (line 28) | def enable_damp(model: nn.Module, std: float):
  function configure_damp_from_args (line 39) | def configure_damp_from_args(model: nn.Module, args):

FILE: nit/models/nvidia_radio/radio/enable_spectral_reparam.py
  class _SNReweight (line 26) | class _SNReweight(_SpectralNorm):
    method __init__ (line 27) | def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: ...
    method _get_sigma (line 56) | def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = N...
    method forward (line 76) | def forward(self, weight: torch.Tensor, *args, **kwargs):
    method _load_from_state_dict (line 95) | def _load_from_state_dict(self, state_dict, prefix, local_metadata, st...
  class _ChunkedSNReweight (line 103) | class _ChunkedSNReweight(nn.Module):
    method __init__ (line 104) | def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_...
    method forward (line 115) | def forward(self, weight: torch.Tensor, *args, **kwargs):
  class _AttnSNReweight (line 126) | class _AttnSNReweight(_ChunkedSNReweight):
    method __init__ (line 127) | def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: ...
  function enable_spectral_reparam (line 134) | def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
  function configure_spectral_reparam_from_args (line 197) | def configure_spectral_reparam_from_args(model: nn.Module, args, state_d...
  function disable_spectral_reparam (line 211) | def disable_spectral_reparam(model: nn.Module):
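`_SNReweight` subclasses PyTorch's `_SpectralNorm` parametrization and exposes `_get_sigma(weight, n_power_iterations=...)`. The sketch below is a pure-Python illustration of what such a sigma estimate usually computes: power iteration for the largest singular value of a weight matrix, followed by a σReparam-style rescale W_hat = (gamma / sigma(W)) * W. The helper names `estimate_sigma` and `sn_reweight` and the fixed starting vector are hypothetical. RADIO's extra options, such as `init_norm_to_current` and the chunked/attention variants, are not modeled.

```python
import math
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, v: List[float]) -> List[float]:
    # W @ v
    return [sum(a * b for a, b in zip(row, v)) for row in w]


def mat_t_vec(w: Matrix, u: List[float]) -> List[float]:
    # W^T @ u
    return [sum(w[i][j] * u[i] for i in range(len(w))) for j in range(len(w[0]))]


def normalize(x: List[float], eps: float = 1e-12) -> List[float]:
    n = math.sqrt(sum(a * a for a in x))
    return [a / max(n, eps) for a in x]


def estimate_sigma(w: Matrix, n_power_iterations: int = 20) -> float:
    """Hypothetical helper: largest singular value of w via power iteration."""
    u = normalize([1.0] * len(w))  # arbitrary non-zero start vector (assumption)
    v = normalize(mat_t_vec(w, u))
    for _ in range(n_power_iterations):
        u = normalize(matvec(w, v))
        v = normalize(mat_t_vec(w, u))
    # sigma ~= u^T W v once u, v approximate the top singular vectors
    return sum(a * b for a, b in zip(u, matvec(w, v)))


def sn_reweight(w: Matrix, gamma: float, n_power_iterations: int = 20) -> Matrix:
    """sigmaReparam-style rescaling (sketch): W_hat = (gamma / sigma(W)) * W."""
    scale = gamma / estimate_sigma(w, n_power_iterations)
    return [[scale * a for a in row] for row in w]
```

After the rescale, the spectral norm of the effective weight equals the learned scalar `gamma`. This is the property the reparametrization relies on to keep attention logits bounded.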

FILE: nit/models/nvidia_radio/radio/eradio_model.py
  class C2f (line 35) | class C2f(nn.Module):
    method __init__ (line 38) | def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=...
    method forward (line 48) | def forward(self, x):
    method forward_split (line 54) | def forward_split(self, x):
  class Bottleneck (line 60) | class Bottleneck(nn.Module):
    method __init__ (line 63) | def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_p...
    method forward (line 71) | def forward(self, x):
  class Conv (line 76) | class Conv(nn.Module):
    method __init__ (line 80) | def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, d...
    method forward (line 91) | def forward(self,x):
    method switch_to_deploy (line 98) | def switch_to_deploy(self):
  function autopad (line 112) | def autopad(k, p=None, d=1):  # kernel, padding, dilation
  function pixel_unshuffle (line 126) | def pixel_unshuffle(data, factor=2):
  class SwiGLU (line 131) | class SwiGLU(nn.Module):
    method forward (line 133) | def forward(self, x):
  function window_partition (line 138) | def window_partition(x, window_size):
  class Conv2d_BN (line 165) | class Conv2d_BN(nn.Module):
    method __init__ (line 170) | def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=...
    method forward (line 178) | def forward(self,x):
    method switch_to_deploy (line 184) | def switch_to_deploy(self):
  function window_reverse (line 197) | def window_reverse(windows, window_size, H, W, pad_hw):
  class PosEmbMLPSwinv2D (line 227) | class PosEmbMLPSwinv2D(nn.Module):
    method __init__ (line 232) | def __init__(
    method relative_bias_initialization (line 261) | def relative_bias_initialization(self, window_size, num_heads, pretrai...
    method switch_to_deploy (line 313) | def switch_to_deploy(self):
    method forward (line 317) | def forward(self, input_tensor):
  class GRAAttentionBlock (line 363) | class GRAAttentionBlock(nn.Module):
    method __init__ (line 364) | def __init__(self, window_size, dim_in, dim_out,
    method forward (line 445) | def forward(self, x):
  class MultiResolutionAttention (line 522) | class MultiResolutionAttention(nn.Module):
    method __init__ (line 530) | def __init__(self, window_size, sr_ratio,
    method forward (line 568) | def forward(self, x):
  class Mlp (line 577) | class Mlp(nn.Module):
    method __init__ (line 582) | def __init__(self,
    method forward (line 605) | def forward(self, x):
  class Downsample (line 614) | class Downsample(nn.Module):
    method __init__ (line 620) | def __init__(self,
    method forward (line 646) | def forward(self, x):
  class PatchEmbed (line 652) | class PatchEmbed(nn.Module):
    method __init__ (line 658) | def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
    method forward (line 683) | def forward(self, x):
  class ConvBlock (line 690) | class ConvBlock(nn.Module):
    method __init__ (line 696) | def __init__(self, dim,
    method forward (line 716) | def forward(self, x):
  class WindowAttention (line 729) | class WindowAttention(nn.Module):
    method __init__ (line 733) | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, re...
    method forward (line 762) | def forward(self, x, attn_mask = None):
  class ERADIOLayer (line 793) | class ERADIOLayer(nn.Module):
    method __init__ (line 798) | def __init__(self,
    method forward (line 905) | def forward(self, x):
  class InterpolateLayer (line 949) | class InterpolateLayer(nn.Module):
    method __init__ (line 950) | def __init__(self, size=None, scale_factor=None, mode='nearest'):
    method forward (line 956) | def forward(self, x):
  class HiResNeck (line 960) | class HiResNeck(nn.Module):
    method __init__ (line 965) | def __init__(self, dim, depths, neck_start_stage, full_features_head_d...
    method forward (line 1006) | def forward(self, x, il_level=-1, full_features=None):
  class ERADIO (line 1020) | class ERADIO(nn.Module):
    method __init__ (line 1025) | def __init__(self,
    method _init_weights (line 1145) | def _init_weights(self, m):
    method no_weight_decay_keywords (line 1161) | def no_weight_decay_keywords(self):
    method forward_features (line 1164) | def forward_features(self, x):
    method forward (line 1184) | def forward(self, x):
    method switch_to_deploy (line 1195) | def switch_to_deploy(self):
    method change_window_size (line 1209) | def change_window_size(self, new_window_size):
    method set_optimal_window_size (line 1242) | def set_optimal_window_size(self, image_dim, max_window_size = 16):
  function eradio_large_fullres_ws16 (line 1293) | def eradio_large_fullres_ws16(pretrained=False, **kwargs):
  function eradio_xxxtiny (line 1318) | def eradio_xxxtiny(pretrained=False, **kwargs):  # ,
  function eradio_xxxtiny_8x_ws12 (line 1342) | def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
  function eradio_xxxtiny_8x_ws16 (line 1367) | def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
  function eradio (line 1391) | def eradio(pretrained=False, **kwargs):
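`autopad(k, p=None, d=1)` carries the comment `# kernel, padding, dilation`, which matches the helper of the same name in Ultralytics YOLO code. If it follows that convention, it returns the padding that keeps the spatial size unchanged for a stride-1 convolution, and it accounts for dilation by first computing the effective kernel size d*(k-1)+1. The pure-Python sketch below follows that convention and is not copied from this repository:

```python
from typing import List, Optional, Union

IntOrList = Union[int, List[int]]


def autopad(k: IntOrList, p: Optional[IntOrList] = None, d: int = 1) -> IntOrList:
    """'Same' padding for a stride-1 conv with kernel size k and dilation d (sketch)."""
    if d > 1:
        # effective (dilated) kernel size, per spatial dimension if k is a list
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
    if p is None:
        # half the effective kernel keeps H and W unchanged when stride == 1
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def conv_out_size(n: int, k: int, p: int, d: int = 1, s: int = 1) -> int:
    """Standard conv output length, used here only to check the padding."""
    return (n + 2 * p - d * (k - 1) - 1) // s + 1
```

An explicit `p` always wins. This lets a caller such as `Conv(..., padding=None)` fall back to 'same' padding only when no padding is given.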

FILE: nit/models/nvidia_radio/radio/extra_models.py
  class PaliGemmaWrapper (line 19) | class PaliGemmaWrapper(nn.Module):
    method __init__ (line 20) | def __init__(self, vis_model: nn.Module, embed_dim: int):
    method patch_size (line 27) | def patch_size(self):
    method blocks (line 31) | def blocks(self):
    method embed_dim (line 35) | def embed_dim(self):
    method forward (line 38) | def forward(self, x: torch.Tensor):
    method forward_features (line 51) | def forward_features(self, x: torch.Tensor):
  function _get_paligemma_model (line 55) | def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch....
  function paligemma_896_student (line 77) | def paligemma_896_student(**kwargs):
  function dv2_sdpa (line 83) | def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
  function _load_dino_v2 (line 99) | def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretra...
  class DinoWrapper (line 116) | class DinoWrapper(nn.Module):
    method __init__ (line 117) | def __init__(self, dino_model: nn.Module):
    method embed_dim (line 124) | def embed_dim(self):
    method patch_size (line 128) | def patch_size(self):
    method num_cls_tokens (line 132) | def num_cls_tokens(self):
    method num_registers (line 136) | def num_registers(self):
    method num_summary_tokens (line 140) | def num_summary_tokens(self):
    method blocks (line 144) | def blocks(self):
    method forward (line 147) | def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    method forward_features (line 155) | def forward_features(self, x: torch.Tensor):
    method patchify (line 162) | def patchify(self, x: torch.Tensor) -> torch.Tensor:
    method forward_intermediates (line 165) | def forward_intermediates(self,
  function _dino_student (line 181) | def _dino_student(arch: str, **kwargs):
  function dino_v2_l_student (line 201) | def dino_v2_l_student(**kwargs):
  function dino_v2_g_student (line 205) | def dino_v2_g_student(**kwargs):

FILE: nit/models/nvidia_radio/radio/extra_timm_models.py
  function vit_tiny_patch14_224 (line 30) | def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  function vit_small_patch14_224 (line 39) | def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  function vit_base_patch14_224 (line 48) | def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  function vit_base_patch16_v2_224 (line 58) | def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransfo...
  function vit_large_patch16_v2_224 (line 72) | def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> Visi...
  function vit_huge_patch16_224 (line 86) | def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  function vit_huge_patch16_224_mlpnorm (line 99) | def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTr...
  function vit_giant_patch16_224 (line 112) | def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **k...
  function vit_bigG_patch14_224 (line 123) | def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  function _create_vision_transformer (line 129) | def _create_vision_transformer(*args, **kwargs):
  function _patch_layer_scale (line 135) | def _patch_layer_scale(model: VisionTransformer):
  class ScaledLayerNorm (line 152) | class ScaledLayerNorm(nn.LayerNorm):
    method __init__ (line 156) | def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
    method forward (line 161) | def forward(self, x):
  class DyT (line 167) | class DyT(nn.Module):
    method __init__ (line 168) | def __init__(self, C: int, init_alpha: float):
    method forward (line 174) | def forward(self, x: torch.Tensor):
  function vit_large_dyt_patch16_224 (line 179) | def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> Vis...
  function _apply_scaled_ln (line 193) | def _apply_scaled_ln(model: VisionTransformer):
  function _replace_ln (line 198) | def _replace_ln(model: VisionTransformer, fn):
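`DyT(C, init_alpha)` and `vit_large_dyt_patch16_224` appear to refer to Dynamic Tanh, a normalization-free drop-in for LayerNorm: DyT(x) = gamma * tanh(alpha * x) + beta, with a learnable scalar alpha and per-channel gamma and beta. The sketch below assumes gamma starts at ones and beta at zeros. It operates on a plain list that stands in for one token's C channels. The class name `DyTSketch` is hypothetical.

```python
import math
from typing import List


class DyTSketch:
    """Hypothetical pure-Python stand-in for a Dynamic Tanh layer over C channels."""

    def __init__(self, C: int, init_alpha: float):
        self.alpha = init_alpha   # learnable scalar in the real layer
        self.weight = [1.0] * C   # per-channel scale (gamma), assumed init = 1
        self.bias = [0.0] * C     # per-channel shift (beta), assumed init = 0

    def __call__(self, x: List[float]) -> List[float]:
        # element-wise squashing; no mean/variance statistics are computed
        return [w * math.tanh(self.alpha * v) + b
                for v, w, b in zip(x, self.weight, self.bias)]
```

Unlike LayerNorm, DyT needs no reduction over channels, so each element is transformed independently. The `ScaledLayerNorm`/`_replace_ln` helpers listed above suggest the normalization layer is swapped in after model construction.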

FILE: nit/models/nvidia_radio/radio/feature_normalizer.py
  function _run_kernel (line 14) | def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
  class FeatureNormalizer (line 27) | class FeatureNormalizer(nn.Module):
    method __init__ (line 28) | def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
    method forward (line 34) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class InterFeatState (line 39) | class InterFeatState(NamedTuple):
  class IntermediateFeatureNormalizerBase (line 44) | class IntermediateFeatureNormalizerBase(nn.Module):
    method forward (line 45) | def forward(self, x: torch.Tensor, index: int, rot_index: int = None, ...
  class IntermediateFeatureNormalizer (line 49) | class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    method __init__ (line 50) | def __init__(self, num_intermediates: int, embed_dim: int, rot_per_lay...
    method forward (line 61) | def forward(self, x: torch.Tensor, index: int, rot_index: int = None, ...
    method _get_rotation (line 89) | def _get_rotation(self, rot_index: int) -> torch.Tensor:
  class NullIntermediateFeatureNormalizer (line 95) | class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
    method __init__ (line 98) | def __init__(self, dtype: torch.dtype, device: torch.device):
    method get_instance (line 103) | def get_instance(dtype: torch.dtype, device: torch.device):
    method forward (line 110) | def forward(self, x: torch.Tensor, index: int, rot_index: int = None, ...

FILE: nit/models/nvidia_radio/radio/forward_intermediates.py
  function _take_indices (line 18) | def _take_indices(
  function forward_intermediates (line 30) | def forward_intermediates(

FILE: nit/models/nvidia_radio/radio/hf_model.py
  class RADIOConfig (line 49) | class RADIOConfig(PretrainedConfig):
    method __init__ (line 52) | def __init__(
  class RADIOModel (line 89) | class RADIOModel(PreTrainedModel):
    method __init__ (line 98) | def __init__(self, config: RADIOConfig):
    method adaptors (line 157) | def adaptors(self) -> nn.ModuleDict:
    method model (line 161) | def model(self) -> VisionTransformer:
    method input_conditioner (line 165) | def input_conditioner(self) -> InputConditioner:
    method num_summary_tokens (line 169) | def num_summary_tokens(self) -> int:
    method patch_size (line 173) | def patch_size(self) -> int:
    method max_resolution (line 177) | def max_resolution(self) -> int:
    method preferred_resolution (line 181) | def preferred_resolution(self) -> Resolution:
    method window_size (line 185) | def window_size(self) -> int:
    method min_resolution_step (line 189) | def min_resolution_step(self) -> int:
    method make_preprocessor_external (line 192) | def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch...
    method get_nearest_supported_resolution (line 195) | def get_nearest_supported_resolution(self, height: int, width: int) ->...
    method switch_to_deploy (line 198) | def switch_to_deploy(self):
    method forward (line 201) | def forward(self, x: torch.Tensor):

FILE: nit/models/nvidia_radio/radio/input_conditioner.py
  class InputConditioner (line 17) | class InputConditioner(nn.Module):
    method __init__ (line 18) | def __init__(self,
    method forward (line 31) | def forward(self, x: torch.Tensor):
  function get_default_conditioner (line 38) | def get_default_conditioner():
  function _to_tensor (line 48) | def _to_tensor(v: norm_t):

FILE: nit/models/nvidia_radio/radio/open_clip_adaptor.py
  class OpenCLIP_RADIO (line 19) | class OpenCLIP_RADIO(GenericAdaptor):
    method __init__ (line 20) | def __init__(self, main_config: Namespace, adaptor_config: dict_t, sta...
    method encode_text (line 35) | def encode_text(self, text, normalize: bool = False):
  function create_open_clip_adaptor (line 40) | def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dic...

FILE: nit/models/nvidia_radio/radio/radio_model.py
  class Resolution (line 25) | class Resolution(NamedTuple):
  class RADIOModel (line 30) | class RADIOModel(nn.Module):
    method __init__ (line 31) | def __init__(
    method num_summary_tokens (line 67) | def num_summary_tokens(self) -> int:
    method num_cls_tokens (line 79) | def num_cls_tokens(self) -> int:
    method patch_size (line 91) | def patch_size(self) -> int:
    method max_resolution (line 102) | def max_resolution(self) -> int:
    method preferred_resolution (line 106) | def preferred_resolution(self) -> Resolution:
    method window_size (line 110) | def window_size(self) -> int:
    method min_resolution_step (line 114) | def min_resolution_step(self) -> int:
    method blocks (line 121) | def blocks(self) -> Iterable[nn.Module]:
    method embed_dim (line 128) | def embed_dim(self) -> int:
    method make_preprocessor_external (line 131) | def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch...
    method get_nearest_supported_resolution (line 136) | def get_nearest_supported_resolution(self, height: int, width: int) ->...
    method switch_to_deploy (line 145) | def switch_to_deploy(self):
    method forward (line 150) | def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[...
    method forward_pack (line 169) | def forward_pack(self, x: List[torch.Tensor], feature_fmt: str = 'NLC'...
    method _extract_final (line 198) | def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt...
    method forward_intermediates (line 264) | def forward_intermediates(
  function create_model_from_args (line 332) | def create_model_from_args(args) -> nn.Module:
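`Resolution` is a `NamedTuple`, and `get_nearest_supported_resolution(height, width)` sits next to the `min_resolution_step` property. One plausible reading, assumed here rather than confirmed by the listing, is that each side is snapped to the nearest multiple of that step (for example the patch size, possibly multiplied by a window size) so that the image tiles evenly into patches. A sketch with the step passed explicitly; `nearest_supported_resolution` is a hypothetical name:

```python
from typing import NamedTuple


class Resolution(NamedTuple):
    height: int
    width: int


def nearest_supported_resolution(height: int, width: int,
                                 min_resolution_step: int) -> Resolution:
    """Snap (height, width) to the nearest multiple of min_resolution_step (sketch)."""
    h = int(round(height / min_resolution_step) * min_resolution_step)
    w = int(round(width / min_resolution_step) * min_resolution_step)
    # never return a zero-sized side
    return Resolution(height=max(h, min_resolution_step),
                      width=max(w, min_resolution_step))
```

Python's `round` uses round-half-to-even, so an exact midpoint such as 40 with a step of 16 (2.5 steps) snaps down to 32 rather than up to 48.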

FILE: nit/models/nvidia_radio/radio/vision_transformer_xpos.py
  function _get_init_scale (line 15) | def _get_init_scale(num_encoder_layers: int, num_decoder_layers: int, is...
  function duplicate_interleave (line 32) | def duplicate_interleave(m):
  function rotate_every_two (line 36) | def rotate_every_two(x):
  class XPosEmbedding2D (line 43) | class XPosEmbedding2D(torch.nn.Module):
    method __init__ (line 49) | def __init__(
    method cos_sin (line 70) | def cos_sin(
    method forward (line 107) | def forward(self, q: torch.Tensor, k: torch.Tensor, token_shape: Tuple...
  class MagnetoAttention (line 120) | class MagnetoAttention(nn.Module):
    method __init__ (line 121) | def __init__(self, d_model: int, n_head: int, pos_emb: XPosEmbedding2D):
    method forward (line 134) | def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape...
    method _reset_parameters (line 165) | def _reset_parameters(self):
  class MagnetoTransformerEncoderLayer (line 173) | class MagnetoTransformerEncoderLayer(nn.Module):
    method __init__ (line 174) | def __init__(self, d_model: int, nhead: int, pos_emb: XPosEmbedding2D,
    method initialize (line 195) | def initialize(self):
    method forward (line 205) | def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape...
    method _sa_block (line 210) | def _sa_block(self, x: torch.Tensor, num_prefix_tokens: int, patch_sha...
    method _ff_block (line 214) | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
  class VisionTransformer (line 223) | class VisionTransformer(nn.Module):
    method __init__ (line 231) | def __init__(
    method num_prefix_tokens (line 283) | def num_prefix_tokens(self):
    method num_summary_tokens (line 287) | def num_summary_tokens(self):
    method forward_features (line 290) | def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, tor...
    method forward_intermediates (line 301) | def forward_intermediates(self, x: torch.Tensor, norm: bool = False, *...
    method _patchify (line 319) | def _patchify(self, x: torch.Tensor):
  function vit_base_patch16_xpos (line 331) | def vit_base_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int =...
  function vit_large_patch16_xpos (line 337) | def vit_large_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int ...
  function vit_huge_patch16_xpos (line 343) | def vit_huge_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int =...
  function vit_giant_patch16_xpos (line 349) | def vit_giant_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int ...
  function vit_bigG_patch16_xpos (line 355) | def vit_bigG_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int =...

FILE: nit/models/nvidia_radio/radio/vit_patch_generator.py
  class ViTPatchGenerator (line 27) | class ViTPatchGenerator(nn.Module):
    method __init__ (line 28) | def __init__(self,
    method forward (line 92) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method apply_cls_token (line 102) | def apply_cls_token(self):
    method num_cls_tokens (line 106) | def num_cls_tokens(self):
    method num_cls_patches (line 110) | def num_cls_patches(self):
    method num_registers (line 114) | def num_registers(self):
    method num_skip (line 118) | def num_skip(self):
    method no_weight_decay (line 121) | def no_weight_decay(self):
    method _load_embed (line 126) | def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
    method _load_projection (line 137) | def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_we...
    method embed_patches (line 148) | def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
    method apply_pos_enc (line 153) | def apply_pos_enc(self,
    method get_pos_enc (line 171) | def get_pos_enc(self,
    method _get_pos_embeddings (line 192) | def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, ...
  class Im2Patches (line 259) | class Im2Patches(nn.Module):
    method __init__ (line 260) | def __init__(self, patch_size: int):
    method forward (line 264) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class ViTPatchLinear (line 279) | class ViTPatchLinear(nn.Linear):
    method __init__ (line 280) | def __init__(self, patch_size: int, embed_dim: int, bias: bool = False...

FILE: nit/models/nvidia_radio/radio/vitdet.py
  class VitDetArgs (line 21) | class VitDetArgs:
    method __init__ (line 22) | def __init__(self,
  function apply_vitdet_arch (line 34) | def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args...
  class ViTDetHook (line 48) | class ViTDetHook:
    method __init__ (line 49) | def __init__(self,
    method _enter_model (line 92) | def _enter_model(self, _, input: List[torch.Tensor]):
    method _enter_blocks (line 95) | def _enter_blocks(self, _, input: List[torch.Tensor]):
    method _to_windows (line 103) | def _to_windows(self, _, input: List[torch.Tensor]):
    method _to_global (line 117) | def _to_global(self, _, input: List[torch.Tensor]):
    method _exit_model (line 134) | def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Te...
    method _rearrange_patches (line 149) | def _rearrange_patches(self, patches: torch.Tensor):

FILE: nit/models/utils/convs.py
  function create_conv_1 (line 10) | def create_conv_1(conv_type, in_channels, out_channels, norm, act_func, ...
  function create_conv_2 (line 31) | def create_conv_2(conv_type, in_channels, out_channels, mid_channels):
  class DWConv (line 52) | class DWConv(nn.Module):
    method __init__ (line 53) | def __init__(
    method forward (line 76) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class DSConv (line 82) | class DSConv(nn.Module):
    method __init__ (line 83) | def __init__(
    method forward (line 118) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class DGConv (line 125) | class DGConv(nn.Module):
    method __init__ (line 126) | def __init__(
    method forward (line 162) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class MBConv (line 169) | class MBConv(nn.Module):
    method __init__ (line 170) | def __init__(
    method forward (line 217) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class FusedMBConv (line 224) | class FusedMBConv(nn.Module):
    method __init__ (line 225) | def __init__(
    method forward (line 264) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class GLUMBConv (line 270) | class GLUMBConv(nn.Module):
    method __init__ (line 271) | def __init__(
    method forward (line 318) | def forward(self, x: torch.Tensor) -> torch.Tensor:

FILE: nit/models/utils/funcs.py
  function modulate (line 6) | def modulate(x, shift, scale):
  function get_parameter_dtype (line 10) | def get_parameter_dtype(parameter: torch.nn.Module):

FILE: nit/models/utils/norms.py
  function create_norm (line 19) | def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
  class FP32_Layernorm (line 60) | class FP32_Layernorm(nn.LayerNorm):
    method forward (line 61) | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  class FusedRMSNorm (line 85) | class FusedRMSNorm(nn.Module):
    method __init__ (line 88) | def __init__(
    method forward (line 98) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method reset_parameters (line 106) | def reset_parameters(self):
  class FusedRMSNorm32 (line 109) | class FusedRMSNorm32(nn.Module):
    method __init__ (line 112) | def __init__(
    method forward (line 122) | def forward(self, x: torch.Tensor) -> torch.Tensor:
    method reset_parameters (line 131) | def reset_parameters(self):
  class RMSNorm (line 134) | class RMSNorm(nn.Module):
    method __init__ (line 135) | def __init__(self, dim: int, include_weight: bool = True, eps: float =...
    method _norm (line 156) | def _norm(self, x):
    method forward (line 169) | def forward(self, x):
  function _rms_norm_fwd_kernel (line 207) | def _rms_norm_fwd_kernel(
  function _rms_norm_bwd_kernel_sm (line 255) | def _rms_norm_bwd_kernel_sm(
  class TritonFusedRMSNorm (line 303) | class TritonFusedRMSNorm(torch.autograd.Function):
    method forward (line 305) | def forward(ctx, x, weight, eps):
    method backward (line 347) | def backward(ctx, dy):
  function fused_rms_norm_fn (line 394) | def fused_rms_norm_fn(
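`RMSNorm(dim, include_weight=True, eps=...)`, `FusedRMSNorm`, and the Triton kernels `_rms_norm_fwd_kernel`/`_rms_norm_bwd_kernel_sm` all target root-mean-square normalization. Assuming the standard definition, each vector is divided by the RMS of its entries (with no mean subtraction, unlike LayerNorm) and then scaled per channel. The reference sketch below is for one vector. The fused and float32 variants presumably differ only in kernel and precision handling, which is not modeled here.

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], weight: Optional[List[float]] = None,
             eps: float = 1e-6) -> List[float]:
    """y_i = x_i / sqrt(mean(x^2) + eps) * w_i  (reference sketch, one vector)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if weight is None:  # mirrors include_weight=False
        weight = [1.0] * len(x)
    return [v / rms * w for v, w in zip(x, weight)]
```

With `eps=0` the output has unit mean square. `eps` only guards against division by zero for all-zero inputs.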

FILE: nit/models/utils/pos_embeds/flash_attn_rotary.py
  function rotate_half (line 11) | def rotate_half(x, interleaved=False):
  function apply_rotary_emb_torch (line 20) | def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
  class ApplyRotaryEmb (line 35) | class ApplyRotaryEmb(torch.autograd.Function):
    method forward (line 37) | def forward(
    method backward (line 70) | def backward(ctx, do):
  function apply_rotary_emb (line 94) | def apply_rotary_emb(
  class ApplyRotaryEmbQKV_ (line 131) | class ApplyRotaryEmbQKV_(torch.autograd.Function):
    method forward (line 133) | def forward(
    method backward (line 199) | def backward(ctx, dqkv):
  function apply_rotary_emb_qkv_ (line 250) | def apply_rotary_emb_qkv_(
  class ApplyRotaryEmbKV_ (line 278) | class ApplyRotaryEmbKV_(torch.autograd.Function):
    method forward (line 280) | def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Unio...
    method backward (line 297) | def backward(ctx, dkv):
  function apply_rotary_emb_kv_ (line 318) | def apply_rotary_emb_kv_(
  class RotaryEmbedding (line 341) | class RotaryEmbedding(torch.nn.Module):
    method __init__ (line 359) | def __init__(
    method _compute_inv_freq (line 404) | def _compute_inv_freq(self, device=None):
    method _update_cos_sin_cache (line 410) | def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
    method forward (line 456) | def forward(
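`rotate_half(x, interleaved=False)` and `apply_rotary_emb_torch(x, cos, sin, interleaved=False)` follow the flash-attention rotary helpers. In the usual RoPE formulation, each pair of channels is rotated by an angle proportional to the token position. The two layouts differ only in which channels form a pair: (i, i + d/2) in the "half" layout and (2i, 2i+1) in the interleaved layout. The sketch below handles a single head vector in pure Python. `apply_rotary` and its `base=10000` frequency schedule are illustrative assumptions and are not this file's API.

```python
import math
from typing import List


def rotate_half(x: List[float], interleaved: bool = False) -> List[float]:
    if not interleaved:
        # "half" layout: channel i pairs with channel i + d/2
        half = len(x) // 2
        return [-v for v in x[half:]] + x[:half]
    # interleaved layout: channel 2i pairs with channel 2i + 1
    out: List[float] = []
    for a, b in zip(x[0::2], x[1::2]):
        out += [-b, a]
    return out


def apply_rotary(x: List[float], pos: float, base: float = 10000.0,
                 interleaved: bool = False) -> List[float]:
    """x * cos + rotate_half(x) * sin, with one angle per channel pair."""
    d = len(x)
    angles = [pos * base ** (-2 * i / d) for i in range(d // 2)]
    if interleaved:
        cos = [c for a in angles for c in (math.cos(a), math.cos(a))]
        sin = [s for a in angles for s in (math.sin(a), math.sin(a))]
    else:
        cos = [math.cos(a) for a in angles] * 2
        sin = [math.sin(a) for a in angles] * 2
    rot = rotate_half(x, interleaved)
    return [v * c + r * s for v, c, r, s in zip(x, cos, rot, sin)]
```

Because each pair is rotated rigidly, the dot product of a rotated query and a rotated key depends only on the difference of their positions. This relative-position property is why RoPE is applied to q and k but not to v.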

FILE: nit/models/utils/pos_embeds/rope.py
  function find_correction_factor (line 23) | def find_correction_factor(num_rotations, dim, base=10000, max_position_...
  function find_correction_range (line 26) | def find_correction_range(low_rot, high_rot, dim, base=10000, max_positi...
  function linear_ramp_mask (line 31) | def linear_ramp_mask(min, max, dim):
  function find_newbase_ntk (line 39) | def find_newbase_ntk(dim, base=10000, scale=1):
  function get_mscale (line 43) | def get_mscale(scale=torch.Tensor):
  function get_proportion (line 49) | def get_proportion(L_test, L_train):
  function rotate_half (line 60) | def rotate_half(x):
  class VisionRotaryEmbedding (line 72) | class VisionRotaryEmbedding(nn.Module):
    method __init__ (line 73) | def __init__(
    method get_1d_rope_freqs (line 128) | def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):
    method online_get_2d_rope_from_grid (line 189) | def online_get_2d_rope_from_grid(self, grid, size):
    method get_2d_rope_from_grid (line 232) | def get_2d_rope_from_grid(self, grid):
    method get_cached_2d_rope_from_grid (line 263) | def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):
    method get_cached_21d_rope_from_grid (line 293) | def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d ...
    method forward (line 320) | def forward(self, x, grid):
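The module-level helpers in `rope.py` (`find_correction_factor`, `find_correction_range`, `linear_ramp_mask`, `find_newbase_ntk`, `get_mscale`) share names and argument lists with the YaRN / NTK-aware RoPE-extension reference code. Assuming they implement the same formulas, the central step is inverting the RoPE wavelength. Channel pair d has wavelength λ_d = 2π · base^(2d/dim), so the pair that completes r full rotations over L positions satisfies L / λ_d = r, which gives d = dim · ln(L / (2πr)) / (2 · ln base). In the pure-Python sketch below, `max_position_embeddings` is a required argument because its default is elided in the listing.

```python
import math
from typing import List, Tuple


def find_correction_factor(num_rotations: float, dim: int, base: float,
                           max_position_embeddings: int) -> float:
    # channel-pair index whose wavelength fits num_rotations times into the context
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot: float, high_rot: float, dim: int, base: float,
                          max_position_embeddings: int) -> Tuple[int, int]:
    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)


def linear_ramp_mask(lo: float, hi: float, dim: int) -> List[float]:
    # 0 below lo, 1 above hi, linear in between: blends interpolated/extrapolated freqs
    if lo == hi:
        hi += 0.001  # avoid division by zero
    return [min(1.0, max(0.0, (i - lo) / (hi - lo))) for i in range(dim)]


def find_newbase_ntk(dim: int, base: float = 10000, scale: float = 1) -> float:
    # "NTK-aware" scaling: enlarge the base instead of squeezing positions
    return base * scale ** (dim / (dim - 2))


def get_mscale(scale: float = 1.0) -> float:
    # attention temperature correction used by YaRN for scale > 1
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0
```

Pairs below the low index rotate many times within the training context and are left unchanged. Pairs above the high index rotate less than once and are interpolated. The ramp mask blends the two regimes in between.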

FILE: nit/models/utils/pos_embeds/sincos.py
  function get_2d_sincos_pos_embed (line 13) | def get_2d_sincos_pos_embed(embed_dim, h, w, frac_coord_size=None, scale...
  function get_2d_sincos_pos_embed_from_grid (line 39) | def get_2d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=N...
  function get_1d_sincos_pos_embed_from_grid (line 68) | def get_1d_sincos_pos_embed_from_grid(pos, embed_dim):
  function get_3d_sincos_pos_embed_from_grid (line 87) | def get_3d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=N...
  function get_21d_sincos_pos_embed_from_grid (line 126) | def get_21d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=...
  function get_time_sincos_pos_embed_from_grid (line 159) | def get_time_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size...
  function interpolate_sincos_pos_embed (line 171) | def interpolate_sincos_pos_embed(embed_dim, ori_h, ori_w, tgt_h, tgt_w):
  function interpolate_sincos_pos_index (line 180) | def interpolate_sincos_pos_index(embed_dim, ori_h, ori_w, tgt_h, tgt_w):
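`get_1d_sincos_pos_embed_from_grid(pos, embed_dim)` and `get_2d_sincos_pos_embed_from_grid` share their names with the MAE/DiT positional-embedding utilities. Assuming the same construction, each position p maps to [sin(p·ω_i) ..., cos(p·ω_i) ...] with ω_i = 10000^(-2i/D), and the 2-D version spends half the channels on each axis. The sketch below is in pure Python. Putting rows before columns is a choice of this sketch (the MAE code puts the width coordinate first). The `frac_coord_size` and `scale` arguments of this repository's versions are not modeled.

```python
import math
from typing import List, Sequence


def sincos_1d(pos: Sequence[float], embed_dim: int) -> List[List[float]]:
    """One row per position: sines for all frequencies, then cosines."""
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (i / half) for i in range(half)]
    return [[math.sin(p * w) for w in omega] + [math.cos(p * w) for w in omega]
            for p in pos]


def sincos_2d(h: int, w: int, embed_dim: int) -> List[List[float]]:
    """Row-major tokens; first half of channels encodes the row, second the column."""
    assert embed_dim % 4 == 0
    rows = [float(i) for i in range(h) for _ in range(w)]
    cols = [float(j) for _ in range(h) for j in range(w)]
    emb_rows = sincos_1d(rows, embed_dim // 2)
    emb_cols = sincos_1d(cols, embed_dim // 2)
    return [a + b for a, b in zip(emb_rows, emb_cols)]
```

Because the embedding is a fixed function of the coordinates, any grid size can be produced on the fly. This matters for native-resolution inputs whose token grids vary from image to image.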

FILE: nit/schedulers/flow_matching/loss.py
  function mean_flat (line 5) | def mean_flat(x):
  function sum_flat (line 11) | def sum_flat(x):
  class FlowMatchingLoss (line 17) | class FlowMatchingLoss:
    method __init__ (line 18) | def __init__(
    method interpolant (line 44) | def interpolant(self, t):
    method __call__ (line 65) | def __call__(self, model, batch_size, images, noises, model_kwargs=Non...
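`FlowMatchingLoss` exposes `interpolant(self, t)`. In SiT/REPA-style flow-matching code, which the names suggest but the listing does not confirm, the interpolant returns the coefficients of x_t = α_t·x_0 + σ_t·ε together with their time derivatives. The model then regresses the velocity v_t = dx_t/dt = α'_t·x_0 + σ'_t·ε. The scalar sketch below assumes the `linear` and `cosine` path types. The loss would be the mean squared error between the model output at (x_t, t) and v_t.

```python
import math
from typing import Tuple


def interpolant(t: float, path_type: str = "linear") -> Tuple[float, float, float, float]:
    """(alpha_t, sigma_t, d_alpha_t, d_sigma_t) for x_t = alpha_t * x0 + sigma_t * eps."""
    if path_type == "linear":
        return 1.0 - t, t, -1.0, 1.0
    if path_type == "cosine":
        a = t * math.pi / 2
        return (math.cos(a), math.sin(a),
                -math.pi / 2 * math.sin(a), math.pi / 2 * math.cos(a))
    raise ValueError(f"unknown path_type: {path_type}")


def flow_matching_pair(x0: float, eps: float, t: float,
                       path_type: str = "linear") -> Tuple[float, float]:
    """Noisy input x_t and its velocity target v_t (time derivative of x_t)."""
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = interpolant(t, path_type)
    x_t = alpha_t * x0 + sigma_t * eps
    v_t = d_alpha_t * x0 + d_sigma_t * eps
    return x_t, v_t
```

For the linear path the target reduces to v_t = ε - x_0, which is independent of t. The trajectories are therefore straight lines from data (t = 0) to noise (t = 1).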

FILE: nit/schedulers/flow_matching/samplers_c2i.py
  function expand_t_like_x (line 5) | def expand_t_like_x(t, x_cur, hw_list):
  function get_score_from_velocity (line 17) | def get_score_from_velocity(vt, xt, t, hw_list, path_type="linear"):
  function compute_diffusion (line 44) | def compute_diffusion(t_cur):
  function euler_sampler (line 48) | def euler_sampler(
  function euler_maruyama_sampler (line 135) | def euler_maruyama_sampler(
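`euler_sampler` presumably integrates the learned velocity field as an ODE, while `euler_maruyama_sampler` adds a stochastic term (see `compute_diffusion`). The minimal deterministic sketch below assumes the convention that t = 1 is pure noise and t = 0 is data, as on the linear path x_t = (1 - t)·x_0 + t·ε. The real function also handles class conditioning, guidance, and the packed `hw_list` layout, none of which is modeled. `euler_sample` is a hypothetical name.

```python
from typing import Callable, List

Vector = List[float]


def euler_sample(velocity: Callable[[Vector, float], Vector], noise: Vector,
                 num_steps: int = 50, t_start: float = 1.0,
                 t_end: float = 0.0) -> Vector:
    """Integrate dx/dt = velocity(x, t) from t_start to t_end with fixed Euler steps."""
    x = list(noise)
    dt = (t_end - t_start) / num_steps  # negative when going from noise to data
    for i in range(num_steps):
        t = t_start + i * dt            # velocity is never queried at t_end itself
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x
```

As a sanity check, when the data distribution is a single point x_0, the exact linear-path velocity is (x_t - x_0) / t. The paths are then straight, so even a handful of Euler steps lands exactly on x_0.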

FILE: nit/utils/deepspeed_zero_to_fp32.py
  class zero_model_state (line 33) | class zero_model_state:
  function atoi (line 48) | def atoi(text):
  function natural_keys (line 52) | def natural_keys(text):
  function get_model_state_file (line 61) | def get_model_state_file(checkpoint_dir, zero_stage):
  function get_checkpoint_files (line 77) | def get_checkpoint_files(checkpoint_dir, glob_pattern):
  function get_optim_files (line 87) | def get_optim_files(checkpoint_dir):
  function get_model_state_files (line 91) | def get_model_state_files(checkpoint_dir):
  function parse_model_states (line 95) | def parse_model_states(files):
  function parse_optim_states (line 141) | def parse_optim_states(files, ds_checkpoint_dir):
  function _get_fp32_state_dict_from_zero_checkpoint (line 194) | def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude...
  function _zero2_merge_frozen_params (line 221) | def _zero2_merge_frozen_params(state_dict, zero_model_states):
  function _has_callable (line 253) | def _has_callable(obj, fn):
  function _zero2_merge_trainable_params (line 258) | def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_grou...
  function _get_fp32_state_dict_from_zero2_checkpoint (line 331) | def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_gro...
  function zero3_partitioned_param_info (line 354) | def zero3_partitioned_param_info(unpartitioned_numel, world_size):
  function _zero3_merge_frozen_params (line 361) | def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
  function _zero3_merge_trainable_params (line 397) | def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_grou...
  function _get_fp32_state_dict_from_zero3_checkpoint (line 451) | def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_gro...
  function get_fp32_state_dict_from_zero_checkpoint (line 474) | def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, e...
  function convert_zero_checkpoint_to_fp32_state_dict (line 524) | def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_fi...
  function load_state_dict_from_zero_checkpoint (line 541) | def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
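`atoi`/`natural_keys` and `zero3_partitioned_param_info` appear to be carried over from DeepSpeed's `zero_to_fp32.py` conversion script. Assuming they behave as in that script, `natural_keys` sorts per-rank shard files numerically rather than lexically. `zero3_partitioned_param_info` gives the per-rank slice length and the padding that ZeRO-3 adds so that a flat parameter splits evenly across ranks. The sketch below is in pure Python:

```python
import math
import re
from typing import List, Tuple, Union


def atoi(text: str) -> Union[int, str]:
    return int(text) if text.isdigit() else text


def natural_keys(text: str) -> List[Union[int, str]]:
    """Sort key that orders 'rank_2' before 'rank_10' (digit runs compared as ints)."""
    return [atoi(c) for c in re.split(r"(\d+)", text)]


def zero3_partitioned_param_info(unpartitioned_numel: int,
                                 world_size: int) -> Tuple[int, int]:
    """(elements per rank, total padding) for an even split across world_size ranks."""
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel
```

When merging shards back into an fp32 state dict, each parameter is rebuilt by concatenating its `partitioned_numel`-sized slices from every rank and dropping the trailing `padding_numel` elements.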

FILE: nit/utils/ema.py
  function update_ema (line 8) | def update_ema(ema_model, model, decay=0.9999):

FILE: nit/utils/eval_utils.py
  function create_npz_from_sample_folder (line 11) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function init_from_ckpt (line 30) | def init_from_ckpt(
  function none_or_str (line 60) | def none_or_str(value):
  function parse_sde_args (line 65) | def parse_sde_args(parser):
  function parse_ode_args (line 77) | def parse_ode_args(parser):

FILE: nit/utils/freeze.py
  function freeze_model (line 5) | def freeze_model(model, trainable_modules={}, verbose=False):

FILE: nit/utils/gpu_memory_monitor.py
  class GPUMemoryMonitor (line 28) | class GPUMemoryMonitor:
    method __init__ (line 29) | def __init__(self, logger, device: str = "cuda:0"):
    method _to_gib (line 43) | def _to_gib(self, memory_in_bytes):
    method _to_pct (line 49) | def _to_pct(self, memory):
    method get_peak_stats (line 52) | def get_peak_stats(self):
    method reset_peak_stats (line 80) | def reset_peak_stats(self):
  function build_gpu_memory_monitor (line 84) | def build_gpu_memory_monitor(logger):

FILE: nit/utils/lr_scheduler.py
  class SchedulerType (line 29) | class SchedulerType(Enum):
  function get_constant_schedule (line 40) | def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
  function get_constant_schedule_with_warmup (line 55) | def get_constant_schedule_with_warmup(
  function get_piecewise_constant_schedule (line 67) | def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: st...
  function get_linear_schedule_with_warmup (line 109) | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_tra...
  function get_cosine_schedule_with_warmup (line 138) | def get_cosine_schedule_with_warmup(
  function get_cosine_with_hard_restarts_schedule_with_warmup (line 172) | def get_cosine_with_hard_restarts_schedule_with_warmup(
  function get_polynomial_decay_schedule_with_warmup (line 207) | def get_polynomial_decay_schedule_with_warmup(
  function get_constant_schedule_with_warmup_and_decay (line 257) | def get_constant_schedule_with_warmup_and_decay(
  function get_scheduler (line 288) | def get_scheduler(
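The `lr_scheduler.py` header credits Hugging Face. Hugging Face's `get_cosine_schedule_with_warmup` builds a `LambdaLR` multiplier: a linear warmup, followed by half a cosine period down to zero.

Below is a torch-free sketch of that multiplier, assuming the Hugging Face default `num_cycles=0.5`. `cosine_warmup_factor` is a hypothetical name.

```python
import math


def cosine_warmup_factor(step: int, num_warmup_steps: int,
                         num_training_steps: int, num_cycles: float = 0.5) -> float:
    """LR multiplier: linear warmup from 0 to 1, then cosine decay toward 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # num_cycles=0.5 traces half a cosine wave: 1.0 at progress 0, 0.0 at progress 1.
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


for s in (0, 50, 100, 550, 1000):
    print(s, round(cosine_warmup_factor(s, 100, 1000), 4))
```

Wrapped as `LambdaLR(optimizer, lambda s: cosine_warmup_factor(s, 100, 1000))`, the factor scales the optimizer's base learning rate at each step.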

FILE: nit/utils/misc_utils.py
  function get_dtype (line 16) | def get_dtype(str_dtype):
  function disabled_train (line 25) | def disabled_train(self, mode=True):
  function get_string_from_tuple (line 31) | def get_string_from_tuple(s):
  function is_power_of_two (line 47) | def is_power_of_two(n):
  function autocast (line 63) | def autocast(f, enabled=True):
  function load_partial_from_config (line 75) | def load_partial_from_config(config):
  function log_txt_as_img (line 79) | def log_txt_as_img(wh, xc, size=10):
  function partialclass (line 109) | def partialclass(cls, *args, **kwargs):
  function make_path_absolute (line 116) | def make_path_absolute(path):
  function ismap (line 123) | def ismap(x):
  function isimage (line 129) | def isimage(x):
  function isheatmap (line 135) | def isheatmap(x):
  function isneighbors (line 142) | def isneighbors(x):
  function exists (line 148) | def exists(x):
  function expand_dims_like (line 152) | def expand_dims_like(x, y):
  function default (line 158) | def default(val, d):
  function mean_flat (line 164) | def mean_flat(tensor):
  function count_params (line 172) | def count_params(model, verbose=False):
  function instantiate_from_config (line 179) | def instantiate_from_config(config):
  function get_obj_from_str (line 189) | def get_obj_from_str(string, reload=False, invalidate_cache=True):
  function append_zero (line 199) | def append_zero(x):
  function append_dims (line 203) | def append_dims(x, target_dims):
  function load_model_from_config (line 213) | def load_model_from_config(config, ckpt, verbose=True, freeze=True):
  function format_number (line 246) | def format_number(num):
  function get_num_params (line 251) | def get_num_params(model: torch.nn.ModuleList) -> int:
  function get_num_flop_per_token (line 256) | def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
  function get_num_flop_per_sequence_encoder_only (line 273) | def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq...
  function get_peak_flops (line 289) | def get_peak_flops(device_name: str) -> int:
  class Color (line 306) | class Color:
  class NoColor (line 319) | class NoColor:
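`get_num_flop_per_token(num_params, model_config, seq_len)` and `get_peak_flops(device_name)` share their names and signatures with torchtitan's MFU utilities. torchtitan estimates 6·N training FLOPs per token for the dense parameters, plus 12·L·H·Q·T for self-attention (L layers, H heads, head dimension Q, sequence length T).

The sketch below assumes NiT's version follows the same formula. The `ModelConfig` fields and the numbers in the usage are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Hypothetical field names; the real config object may differ.
    n_layers: int
    n_heads: int
    dim: int


def flop_per_token(num_params: int, cfg: ModelConfig, seq_len: int) -> int:
    """Training FLOPs per token (forward + backward), torchtitan-style estimate."""
    head_dim = cfg.dim // cfg.n_heads
    # 6N: 2 FLOPs per parameter in the forward pass, 4 in the backward pass.
    # 12*L*H*Q*T: attention score and value matmuls, forward + backward.
    return 6 * num_params + 12 * cfg.n_layers * cfg.n_heads * head_dim * seq_len


def mfu(tokens_per_second: float, flops_per_token: int, peak_flops: float) -> float:
    """Model FLOPs utilization: achieved FLOP/s over the device's peak FLOP/s."""
    return tokens_per_second * flops_per_token / peak_flops


cfg = ModelConfig(n_layers=28, n_heads=16, dim=1152)
fpt = flop_per_token(num_params=675_000_000, cfg=cfg, seq_len=256)
# 312e12 is an illustrative peak (A100 dense BF16).
print(fpt, mfu(tokens_per_second=50_000, flops_per_token=fpt, peak_flops=312e12))
```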

FILE: nit/utils/model_utils.py
  function dc_ae_encode (line 9) | def dc_ae_encode(dc_ae, images):
  function dc_ae_decode (line 14) | def dc_ae_decode(dc_ae, latents):
  function sd_vae_encode (line 26) | def sd_vae_encode(sd_vae, images):
  function sd_vae_decode (line 34) | def sd_vae_decode(sd_vae, latents):
  function load_text_encoder (line 46) | def load_text_encoder(text_encoder_dir, device, weight_dtype):
  function encode_prompt (line 66) | def encode_prompt(tokenizer, text_encoder, device, weight_dtype, caption...
  function prepare_null_cap_feat_mask (line 91) | def prepare_null_cap_feat_mask(text_encoder_type, device, weight_dtype, ...

FILE: nit/utils/train_utils.py
  function freeze_model (line 8) | def freeze_model(model, trainable_modules={}, verbose=False):
  function update_ema (line 29) | def update_ema(ema_model, model, decay=0.9999):
  function log_validation (line 46) | def log_validation(model):

FILE: nit/utils/util.py
  function disabled_train (line 16) | def disabled_train(self, mode=True):
  function get_string_from_tuple (line 22) | def get_string_from_tuple(s):
  function is_power_of_two (line 38) | def is_power_of_two(n):
  function autocast (line 54) | def autocast(f, enabled=True):
  function load_partial_from_config (line 66) | def load_partial_from_config(config):
  function log_txt_as_img (line 70) | def log_txt_as_img(wh, xc, size=10):
  function partialclass (line 100) | def partialclass(cls, *args, **kwargs):
  function make_path_absolute (line 107) | def make_path_absolute(path):
  function ismap (line 114) | def ismap(x):
  function isimage (line 120) | def isimage(x):
  function isheatmap (line 126) | def isheatmap(x):
  function isneighbors (line 133) | def isneighbors(x):
  function exists (line 139) | def exists(x):
  function expand_dims_like (line 143) | def expand_dims_like(x, y):
  function default (line 149) | def default(val, d):
  function mean_flat (line 155) | def mean_flat(tensor):
  function count_params (line 163) | def count_params(model, verbose=False):
  function instantiate_from_config (line 170) | def instantiate_from_config(config):
  function get_obj_from_str (line 180) | def get_obj_from_str(string, reload=False, invalidate_cache=True):
  function append_zero (line 190) | def append_zero(x):
  function append_dims (line 194) | def append_dims(x, target_dims):
  function load_model_from_config (line 204) | def load_model_from_config(config, ckpt, verbose=True, freeze=True):
  function format_number (line 237) | def format_number(num):
  function get_num_params (line 242) | def get_num_params(model: torch.nn.ModuleList) -> int:
  function get_num_flop_per_token (line 247) | def get_num_flop_per_token(num_params, model_config, seq_len) -> int:
  function get_num_flop_per_sequence_encoder_only (line 264) | def get_num_flop_per_sequence_encoder_only(num_params, model_config, seq...
  function get_peak_flops (line 280) | def get_peak_flops(device_name: str) -> int:
  class Color (line 297) | class Color:
  class NoColor (line 310) | class NoColor:

FILE: nit/utils/video_utils.py
  function save_video_as_mp4 (line 6) | def save_video_as_mp4(video_array, fps, output_path):
  function save_video_as_png (line 17) | def save_video_as_png(video_array, output_path):

FILE: nit/utils/warp_pos_idx.py
  function warp_pos_idx_from_grid (line 6) | def warp_pos_idx_from_grid(
  function warp_pos_idx (line 24) | def warp_pos_idx(

FILE: projects/evaluate/adm_evaluator.py
  function main (line 28) | def main():
  class InvalidFIDException (line 64) | class InvalidFIDException(Exception):
  class FIDStatistics (line 68) | class FIDStatistics:
    method __init__ (line 69) | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
    method frechet_distance (line 73) | def frechet_distance(self, other, eps=1e-6):
  class Evaluator (line 119) | class Evaluator:
    method __init__ (line 120) | def __init__(
    method warmup (line 136) | def warmup(self):
    method read_activations (line 139) | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndar...
    method compute_activations (line 169) | def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[...
    method read_statistics (line 191) | def read_statistics(
    method compute_statistics (line 202) | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
    method compute_inception_score (line 207) | def compute_inception_score(self, activations: np.ndarray, split_size:...
    method compute_prec_recall (line 222) | def compute_prec_recall(
  class ManifoldEstimator (line 233) | class ManifoldEstimator:
    method __init__ (line 240) | def __init__(
    method warmup (line 269) | def warmup(self):
    method manifold_radii (line 276) | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
    method evaluate (line 311) | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_featu...
    method evaluate_pr (line 353) | def evaluate_pr(
  class DistanceBlock (line 390) | class DistanceBlock:
    method __init__ (line 397) | def __init__(self, session):
    method pairwise_distances (line 421) | def pairwise_distances(self, U, V):
    method less_thans (line 430) | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
  function _batch_pairwise_distances (line 442) | def _batch_pairwise_distances(U, V):
  class NpzArrayReader (line 461) | class NpzArrayReader(ABC):
    method read_batch (line 463) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 467) | def remaining(self) -> int:
    method read_batches (line 470) | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
  class BatchIterator (line 483) | class BatchIterator:
    method __init__ (line 484) | def __init__(self, gen_fn, length):
    method __len__ (line 488) | def __len__(self):
    method __iter__ (line 491) | def __iter__(self):
  class StreamingNpzArrayReader (line 495) | class StreamingNpzArrayReader(NpzArrayReader):
    method __init__ (line 496) | def __init__(self, arr_f, shape, dtype):
    method read_batch (line 502) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 517) | def remaining(self) -> int:
  class MemoryNpzArrayReader (line 521) | class MemoryNpzArrayReader(NpzArrayReader):
    method __init__ (line 522) | def __init__(self, arr):
    method load (line 527) | def load(cls, path: str, arr_name: str):
    method read_batch (line 532) | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
    method remaining (line 540) | def remaining(self) -> int:
  function open_npz_array (line 545) | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
  function _read_bytes (line 562) | def _read_bytes(fp, size, error_template="ran out of data"):
  function _open_npy_file (line 592) | def _open_npy_file(path: str, arr_name: str):
  function _download_inception_model (line 601) | def _download_inception_model():
  function _create_feature_graph (line 614) | def _create_feature_graph(input_batch):
  function _create_softmax_graph (line 631) | def _create_softmax_graph(input_batch):
  function _update_shapes (line 645) | def _update_shapes(pool3):
  function _numpy_partition (line 664) | def _numpy_partition(arr, kth, **kwargs):
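`FIDStatistics.frechet_distance` and `Evaluator.compute_inception_score` compute the standard metrics of the ADM evaluation suite. Let the reference set's Inception pool features have mean $\mu_r$ and covariance $\Sigma_r$, and the generated set's have $\mu_g$ and $\Sigma_g$. The metrics are then:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)

\mathrm{IS} = \exp\!\Big( \mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big(p(y \mid x)\,\Vert\, p(y)\big) \Big),
\qquad p(y) = \mathbb{E}_{x \sim p_g}\, p(y \mid x)
```

In the ADM reference implementation, `eps` is added to the covariance diagonals when the matrix square root is numerically unstable. The Inception Score is computed on chunks of `split_size` samples and then averaged.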

FILE: projects/preprocess/image_latent_c2i.py
  function resolve_tuple (line 42) | def resolve_tuple(*args):
  function parse_args (line 47) | def parse_args():
  function center_crop_arr (line 73) | def center_crop_arr(pil_image, image_size):
  class ImageFolder (line 95) | class ImageFolder(DatasetFolder):
    method __init__ (line 125) | def __init__(
    method __getitem__ (line 143) | def __getitem__(self, index: int) -> Tuple[Any, Any]:
  class ImagenetDataDictWrapper (line 160) | class ImagenetDataDictWrapper(Dataset):
    method __init__ (line 161) | def __init__(self, dataset):
    method __getitem__ (line 165) | def __getitem__(self, i):
    method __len__ (line 169) | def __len__(self):
  function get_train_sampler (line 173) | def get_train_sampler(global_batch_size, max_steps, resume_step):
  class ImagenetLoader (line 178) | class ImagenetLoader():
    method __init__ (line 179) | def __init__(self, data_config):
    method train_len (line 196) | def train_len(self):
    method train_dataloader (line 199) | def train_dataloader(self, global_batch_size, max_steps, resume_step):
    method test_dataloader (line 212) | def test_dataloader(self):
    method val_dataloader (line 215) | def val_dataloader(self):
  function main (line 224) | def main(args):
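`center_crop_arr(pil_image, image_size)` has the name of the ADM preprocessing helper. That helper works in three steps:

- It box-downsamples by 2 while the short side is at least twice the target.
- It rescales the short side to exactly `image_size` with bicubic resampling.
- It center-crops to a square.

Below is a PIL-free sketch of the geometry only, assuming that ADM behavior. `center_crop_plan` is a hypothetical helper that returns sizes and a crop box instead of an image.

```python
from typing import Tuple


def center_crop_plan(width: int, height: int,
                     image_size: int) -> Tuple[Tuple[int, int], Tuple[int, int, int, int]]:
    """Return (resized (w, h), crop box (left, top, right, bottom))."""
    # Repeated 2x box downsampling while the short side is >= 2 * target.
    while min(width, height) >= 2 * image_size:
        width, height = width // 2, height // 2
    # Bicubic resize so that the short side equals image_size.
    scale = image_size / min(width, height)
    width, height = round(width * scale), round(height * scale)
    # Center crop to image_size x image_size.
    left = (width - image_size) // 2
    top = (height - image_size) // 2
    return (width, height), (left, top, left + image_size, top + image_size)


print(center_crop_plan(1280, 960, 256))
```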

FILE: projects/preprocess/image_nr_latent_c2i.py
  function resolve_tuple (line 42) | def resolve_tuple(*args):
  function parse_args (line 47) | def parse_args():
  function native_resolution_resize (line 73) | def native_resolution_resize(pil_image, min_image_size, max_image_size):
  class ImageFolder (line 92) | class ImageFolder(DatasetFolder):
    method __init__ (line 122) | def __init__(
    method __getitem__ (line 140) | def __getitem__(self, index: int) -> Tuple[Any, Any]:
  class ImagenetDataDictWrapper (line 157) | class ImagenetDataDictWrapper(Dataset):
    method __init__ (line 158) | def __init__(self, dataset):
    method __getitem__ (line 162) | def __getitem__(self, i):
    method __len__ (line 166) | def __len__(self):
  function get_train_sampler (line 170) | def get_train_sampler(global_batch_size, max_steps, resume_step):
  class ImagenetLoader (line 175) | class ImagenetLoader():
    method __init__ (line 176) | def __init__(self, data_config):
    method train_len (line 195) | def train_len(self):
    method train_dataloader (line 198) | def train_dataloader(self, global_batch_size, max_steps, resume_step):
    method test_dataloader (line 211) | def test_dataloader(self):
    method val_dataloader (line 214) | def val_dataloader(self):
  function main (line 223) | def main(args):
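The signature `native_resolution_resize(pil_image, min_image_size, max_image_size)` suggests that images keep their own aspect ratio and are rescaled only when they fall outside a size range.

The sketch below shows one hypothetical policy, not the repository's. It shrinks images whose longer side exceeds `max_image_size` and enlarges images whose shorter side is below `min_image_size`. It then floors both sides to a multiple of 32, the spatial downsampling factor of the `dc-ae-f32c32` autoencoder named in the preprocessing configs, so the latent grid has an integer size. `native_resolution_size` is a hypothetical name.

```python
from typing import Tuple


def native_resolution_size(width: int, height: int, min_image_size: int,
                           max_image_size: int, multiple: int = 32) -> Tuple[int, int]:
    """Hypothetical size policy: keep aspect ratio, bound the size, snap to `multiple`."""
    scale = 1.0
    if max(width, height) > max_image_size:
        scale = max_image_size / max(width, height)    # shrink oversized images
    elif min(width, height) < min_image_size:
        scale = min_image_size / min(width, height)    # enlarge undersized images
    # Floor each side to a multiple of the autoencoder's downsampling factor.
    new_w = max(multiple, round(width * scale) // multiple * multiple)
    new_h = max(multiple, round(height * scale) // multiple * multiple)
    return new_w, new_h


w, h = native_resolution_size(1500, 1000, 256, 1024)
print((w, h), (w // 32, h // 32))  # pixel size and latent grid size for an f32 autoencoder
```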

FILE: projects/sample/sample_c2i_ddp.py
  function create_npz_from_sample_folder (line 32) | def create_npz_from_sample_folder(sample_dir, num=50_000):
  function main (line 49) | def main(args):
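`sample_c2i_ddp.py` carries a Meta copyright header and a `create_npz_from_sample_folder(sample_dir, num=50_000)` helper. Both suggest it is adapted from DiT's `sample_ddp.py`. In that DiT layout:

- The requested sample count is rounded up to a multiple of the global batch, so every rank runs the same number of iterations.
- Ranks write interleaved file indices.
- The surplus images are ignored when the `.npz` is built.

Below is a sketch of that bookkeeping under the DiT assumption. `plan_ddp_sampling` and `sample_indices` are hypothetical helpers.

```python
import math
from typing import Dict, List


def plan_ddp_sampling(num_samples: int, per_proc_batch: int, world_size: int) -> Dict[str, int]:
    """Round the sample budget up so every rank runs the same number of iterations."""
    global_batch = per_proc_batch * world_size
    total = math.ceil(num_samples / global_batch) * global_batch
    per_rank = total // world_size
    return {"total": total, "per_rank": per_rank, "iterations": per_rank // per_proc_batch}


def sample_indices(rank: int, iteration: int, per_proc_batch: int, world_size: int) -> List[int]:
    """Global file indices (e.g. 000123.png) written by one rank in one iteration.

    Ranks interleave, so the indices from all ranks tile [0, total) without gaps.
    """
    base = iteration * per_proc_batch * world_size
    return [base + j * world_size + rank for j in range(per_proc_batch)]


plan = plan_ddp_sampling(num_samples=50_000, per_proc_batch=32, world_size=8)
print(plan)
print(sample_indices(rank=1, iteration=0, per_proc_batch=4, world_size=2))
```

Under this convention, only the first `num` files are read back into the `.npz` used for evaluation.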

FILE: projects/train/packed_trainer_c2i.py
  function parse_args (line 82) | def parse_args():
  function main (line 108) | def main(args):
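The packed trainer consumes batches built by `packed_c2i_data.py` and `nit/data/pack`. In these batches, several images' token sequences share one row. A common way to keep attention within each image is to pass cumulative sequence offsets (`cu_seqlens`) to a variable-length attention kernel, such as FlashAttention's `flash_attn_varlen_func`.

The sketch below assumes that approach; this index does not show exactly which tensors NiT's attention uses. It derives the offsets, per-token image ids, and an equivalent block-diagonal mask from per-image token counts. `pack_offsets` and `block_diagonal_mask` are hypothetical names.

```python
from itertools import accumulate
from typing import List, Tuple


def pack_offsets(seq_lens: List[int]) -> Tuple[List[int], List[int]]:
    """cu_seqlens-style offsets and a per-token sequence id for one packed row."""
    cu_seqlens = [0] + list(accumulate(seq_lens))
    seq_ids = [i for i, n in enumerate(seq_lens) for _ in range(n)]
    return cu_seqlens, seq_ids


def block_diagonal_mask(seq_ids: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when query token q may attend to key token k."""
    return [[a == b for b in seq_ids] for a in seq_ids]


cu, ids = pack_offsets([4, 2, 3])    # e.g. latent grids of 2x2, 1x2 and 1x3 tokens
print(cu, ids)
print(block_diagonal_mask(ids[:6])[4])
```

The dense mask grows quadratically with row length. Variable-length kernels avoid materializing it, which is why they take offsets instead.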

FILE: tools/pack_dataset.py
  function create_pack (line 7) | def create_pack(data_meta, max_seq_len, algorithm, split):
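`create_pack(data_meta, max_seq_len, algorithm, split)` presumably selects one of the Graphcore-derived packers in `nit/data/pack` (SPFHP, LPFHP, NNLSHP, ENNLSHP). These packers group variable-length sequences into rows of at most `max_seq_len` tokens. They are histogram-based.

The sketch below swaps in plain first-fit-decreasing bin packing, only to illustrate the packing objective. It is not one of the repository's algorithms, and `first_fit_decreasing` is a hypothetical name.

```python
from typing import List


def first_fit_decreasing(seq_lens: List[int], max_seq_len: int) -> List[List[int]]:
    """Pack sample indices into rows whose total length is <= max_seq_len."""
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    packs: List[List[int]] = []
    remaining: List[int] = []          # free tokens left in each pack
    for idx in order:
        length = seq_lens[idx]
        if length > max_seq_len:
            raise ValueError(f"sample {idx} ({length} tokens) exceeds max_seq_len")
        for p, free in enumerate(remaining):
            if length <= free:         # first existing pack with room
                packs[p].append(idx)
                remaining[p] -= length
                break
        else:                          # no pack fits: open a new one
            packs.append([idx])
            remaining.append(max_seq_len - length)
    return packs


lens = [256, 1024, 64, 512, 768, 256]
print(first_fit_decreasing(lens, max_seq_len=1024))
```

Each row wastes `max_seq_len - sum(lengths)` padding tokens, so fuller rows waste less compute per training step.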

Condensed preview: 91 files, each showing path, character count, and a content snippet. The full structured content is 604K chars.
[
  {
    "path": ".gitignore",
    "chars": 3511,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11356,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 11977,
    "preview": "<h1 align=\"center\"> Native-Resolution Image Synthesis</h1>\n\n<!-- \n[![PWC](https://img.shields.io/endpoint.svg?url=https:"
  },
  {
    "path": "configs/c2i/nit_b_pack_merge_radio_65536.yaml",
    "chars": 2169,
    "preview": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models."
  },
  {
    "path": "configs/c2i/nit_l_pack_merge_radio_16384.yaml",
    "chars": 2170,
    "preview": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models."
  },
  {
    "path": "configs/c2i/nit_s_pack_merge_radio_65536.yaml",
    "chars": 2197,
    "preview": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models."
  },
  {
    "path": "configs/c2i/nit_xl_pack_merge_radio_16384.yaml",
    "chars": 2170,
    "preview": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models."
  },
  {
    "path": "configs/c2i/nit_xxl_pack_merge_radio_8192.yaml",
    "chars": 2214,
    "preview": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models."
  },
  {
    "path": "configs/preprocess/imagenet1k_256x256.yaml",
    "chars": 1135,
    "preview": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/tr"
  },
  {
    "path": "configs/preprocess/imagenet1k_512x512.yaml",
    "chars": 1135,
    "preview": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/tr"
  },
  {
    "path": "configs/preprocess/imagenet1k_native_resolution.yaml",
    "chars": 1172,
    "preview": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/tr"
  },
  {
    "path": "nit/data/pack/__init__.py",
    "chars": 2022,
    "preview": "from .ennlshp import ENNLSHP\nfrom .lpfhp import LPFHP\nfrom .nnlshp import NNLSHP\nfrom .spfhp import SPFHP\n\nimport json\ni"
  },
  {
    "path": "nit/data/pack/ennlshp.py",
    "chars": 4899,
    "preview": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2"
  },
  {
    "path": "nit/data/pack/lpfhp.py",
    "chars": 6133,
    "preview": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2"
  },
  {
    "path": "nit/data/pack/nnlshp.py",
    "chars": 4011,
    "preview": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2"
  },
  {
    "path": "nit/data/pack/spfhp.py",
    "chars": 4771,
    "preview": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2"
  },
  {
    "path": "nit/data/packed_c2i_data.py",
    "chars": 6326,
    "preview": "import os\nimport datetime\nimport torchvision\nimport numpy as np\nimport torch\nimport ast\nimport json\nimport time\n\n\nfrom o"
  },
  {
    "path": "nit/data/sampler_util.py",
    "chars": 2097,
    "preview": "import torch\nimport json\n\n# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L6"
  },
  {
    "path": "nit/models/c2i/nit_model.py",
    "chars": 14753,
    "preview": "# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n"
  },
  {
    "path": "nit/models/nvidia_radio/hubconf.py",
    "chars": 7563,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/__init__.py",
    "chars": 770,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_base.py",
    "chars": 1207,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_generic.py",
    "chars": 3306,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_mlp.py",
    "chars": 6017,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_registry.py",
    "chars": 1371,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/block.py",
    "chars": 1974,
    "preview": "# Ultralytics YOLO 🚀, AGPL-3.0 license\n\"\"\"\nBlock modules\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers"
  },
  {
    "path": "nit/models/nvidia_radio/radio/cls_token.py",
    "chars": 1787,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/common.py",
    "chars": 3869,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/conv.py",
    "chars": 2411,
    "preview": "# Ultralytics YOLO 🚀, AGPL-3.0 license\n\"\"\"\nConvolution modules\n\"\"\"\n\nimport math\n\nimport numpy as np\nimport torch\nimport "
  },
  {
    "path": "nit/models/nvidia_radio/radio/dinov2_arch.py",
    "chars": 35529,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version "
  },
  {
    "path": "nit/models/nvidia_radio/radio/dual_hybrid_vit.py",
    "chars": 7179,
    "preview": "from logging import getLogger\nfrom typing import Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functiona"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_cpe_support.py",
    "chars": 8154,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_damp.py",
    "chars": 1290,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_spectral_reparam.py",
    "chars": 11004,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/eradio_model.py",
    "chars": 58123,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its l"
  },
  {
    "path": "nit/models/nvidia_radio/radio/extra_models.py",
    "chars": 5691,
    "preview": "from distutils.version import LooseVersion\nfrom types import MethodType\nfrom typing import List, Optional, Tuple, Union\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/extra_timm_models.py",
    "chars": 8205,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/feature_normalizer.py",
    "chars": 4399,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/forward_intermediates.py",
    "chars": 5604,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/hf_model.py",
    "chars": 7684,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 "
  },
  {
    "path": "nit/models/nvidia_radio/radio/input_conditioner.py",
    "chars": 1492,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/open_clip_adaptor.py",
    "chars": 1550,
    "preview": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all inte"
  },
  {
    "path": "nit/models/nvidia_radio/radio/radio_model.py",
    "chars": 15060,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vision_transformer_xpos.py",
    "chars": 13097,
    "preview": "import math\nfrom typing import Final, List, Optional, Tuple, Union\n\n\nfrom einops import rearrange\nfrom timm.models impor"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vit_patch_generator.py",
    "chars": 10961,
    "preview": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vitdet.py",
    "chars": 6318,
    "preview": "from collections import defaultdict\nfrom contextlib import contextmanager\nfrom logging import getLogger\nimport math\nimpo"
  },
  {
    "path": "nit/models/utils/convs.py",
    "chars": 8887,
    "preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom nit.models.efficientvit.models.nn.ops import C"
  },
  {
    "path": "nit/models/utils/funcs.py",
    "chars": 917,
    "preview": "import torch\nfrom torch import Tensor\nfrom typing import List, Tuple\nfrom itertools import chain\n\ndef modulate(x, shift,"
  },
  {
    "path": "nit/models/utils/norms.py",
    "chars": 11311,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "nit/models/utils/pos_embeds/flash_attn_rotary.py",
    "chars": 20225,
    "preview": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom einops import r"
  },
  {
    "path": "nit/models/utils/pos_embeds/rope.py",
    "chars": 15938,
    "preview": "# --------------------------------------------------------\n# FiT: A Flexible Vision Transformer for Image Generation\n#\n#"
  },
  {
    "path": "nit/models/utils/pos_embeds/sincos.py",
    "chars": 7914,
    "preview": "#################################################################################\n#                   Sine/Cosine Positi"
  },
  {
    "path": "nit/schedulers/flow_matching/loss.py",
    "chars": 4450,
    "preview": "import torch\nimport numpy as np\nimport torch.nn.functional as F\n\ndef mean_flat(x):\n    \"\"\"\n    Take the mean over all no"
  },
  {
    "path": "nit/schedulers/flow_matching/samplers_c2i.py",
    "chars": 10737,
    "preview": "import torch\nimport numpy as np\n\n\ndef expand_t_like_x(t, x_cur, hw_list):\n    \"\"\"Function to reshape time t to broadcast"
  },
  {
    "path": "nit/utils/__init__.py",
    "chars": 112,
    "preview": "from .misc_utils import *\nfrom .train_utils import *\nfrom .eval_utils import *\nfrom .gpu_memory_monitor import *"
  },
  {
    "path": "nit/utils/deepspeed_zero_to_fp32.py",
    "chars": 25314,
    "preview": "#!/usr/bin/env python\n\n# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\n#"
  },
  {
    "path": "nit/utils/ema.py",
    "chars": 699,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\n\n\n\n@torch.no_grad()\ndef update_ema(ema_model,"
  },
  {
    "path": "nit/utils/eval_utils.py",
    "chars": 3865,
    "preview": "from PIL import Image\nimport numpy as np\nfrom tqdm import tqdm\nimport torch\nimport re\nimport os\n\nfrom safetensors.torch "
  },
  {
    "path": "nit/utils/freeze.py",
    "chars": 1022,
    "preview": "from diffusers.utils import logging\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\ndef freeze_m"
  },
  {
    "path": "nit/utils/gpu_memory_monitor.py",
    "chars": 2793,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the B"
  },
  {
    "path": "nit/utils/lr_scheduler.py",
    "chars": 15508,
    "preview": "from torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\n# coding=utf-8\n# Copyright 2023 The Hu"
  },
  {
    "path": "nit/utils/misc_utils.py",
    "chars": 9717,
    "preview": "import functools\nimport importlib\nimport os\nimport wandb\nimport fsspec\nimport numpy as np\nimport torch\n\nfrom dataclasses"
  },
  {
    "path": "nit/utils/model_utils.py",
    "chars": 3350,
    "preview": "import os\nimport torch\nfrom transformers import T5EncoderModel, AutoModelForCausalLM, AutoTokenizer\n\n\n\n\n# dc-ae\ndef dc_a"
  },
  {
    "path": "nit/utils/train_utils.py",
    "chars": 1755,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom diffusers.utils import logging\n\nlogger ="
  },
  {
    "path": "nit/utils/util.py",
    "chars": 9525,
    "preview": "import functools\nimport importlib\nimport os\nimport wandb\nimport fsspec\nimport numpy as np\nimport torch\n\nfrom dataclasses"
  },
  {
    "path": "nit/utils/video_utils.py",
    "chars": 872,
    "preview": "import os\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\ndef save_video_as_mp4(video_array, fps, output_path):\n   "
  },
  {
    "path": "nit/utils/warp_pos_idx.py",
    "chars": 1625,
    "preview": "import torch\nimport random\nfrom typing import Optional, Union\n\n\ndef warp_pos_idx_from_grid(\n    grid: torch.Tensor, \n   "
  },
  {
    "path": "projects/evaluate/adm_evaluator.py",
    "chars": 26384,
    "preview": "import argparse\nimport io\nimport os\nimport random\nimport warnings\nimport zipfile\nfrom abc import ABC, abstractmethod\nfro"
  },
  {
    "path": "projects/preprocess/image_latent_c2i.py",
    "chars": 16792,
    "preview": "import os\nimport torch\nimport argparse\nimport datetime\nimport time\nimport torchvision\nimport logging\nimport math\nimport "
  },
  {
    "path": "projects/preprocess/image_nr_latent_c2i.py",
    "chars": 16838,
    "preview": "import os\nimport torch\nimport argparse\nimport datetime\nimport time\nimport torchvision\nimport logging\nimport math\nimport "
  },
  {
    "path": "projects/sample/sample_c2i_ddp.py",
    "chars": 11203,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "projects/train/packed_trainer_c2i.py",
    "chars": 21452,
    "preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under"
  },
  {
    "path": "requirements.txt",
    "chars": 571,
    "preview": "diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested\ntransformers>=4.44.2"
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_256x256.sh",
    "chars": 509,
    "preview": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/p"
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_512x512.sh",
    "chars": 509,
    "preview": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/p"
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_native_resolution.sh",
    "chars": 532,
    "preview": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/p"
  },
  {
    "path": "scripts/sample/sample_256x256.sh",
    "chars": 408,
    "preview": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pac"
  },
  {
    "path": "scripts/sample/sample_512x512.sh",
    "chars": 408,
    "preview": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pac"
  },
  {
    "path": "scripts/sample/sample_768x768.sh",
    "chars": 406,
    "preview": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pac"
  },
  {
    "path": "scripts/train/train_b_model.sh",
    "chars": 538,
    "preview": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_b_pack_merge_radio_65"
  },
  {
    "path": "scripts/train/train_l_model.sh",
    "chars": 538,
    "preview": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_l_pack_merge_radio_16"
  },
  {
    "path": "scripts/train/train_s_model.sh",
    "chars": 538,
    "preview": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_s_pack_merge_radio_65"
  },
  {
    "path": "scripts/train/train_xl_model.sh",
    "chars": 541,
    "preview": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_xl_pack_merge_radio_1"
  },
  {
    "path": "scripts/train/train_xxl_model.sh",
    "chars": 541,
    "preview": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_xxl_pack_merge_radio_"
  },
  {
    "path": "setup.py",
    "chars": 204,
    "preview": "from setuptools import find_packages, setup\n\nsetup(\n    name='nit',\n    version='0.0.1',\n    description='',\n    package"
  },
  {
    "path": "tools/download_dataset_256x256.sh",
    "chars": 741,
    "preview": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256\"\nmkdir -p $target_dir\nbase_url=\"https://huggingf"
  },
  {
    "path": "tools/download_dataset_512x512.sh",
    "chars": 1191,
    "preview": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512\"\nmkdir -p $target_dir\nbase_url=\"https://huggingf"
  },
  {
    "path": "tools/download_dataset_data_meta.sh",
    "chars": 610,
    "preview": "target_dir=\"datasets/imagenet1k/data_meta\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT"
  },
  {
    "path": "tools/download_dataset_native_resolution.sh",
    "chars": 911,
    "preview": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution\"\nmkdir -p $target_dir\nbase_url=\"https:"
  },
  {
    "path": "tools/download_dataset_sampler_meta.sh",
    "chars": 641,
    "preview": "target_dir=\"datasets/imagenet1k/sampler_meta\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/"
  },
  {
    "path": "tools/pack_dataset.py",
    "chars": 1836,
    "preview": "import json\nfrom nit.data.pack import pack_dataset\nimport argparse\n\n\n\ndef create_pack(data_meta, max_seq_len, algorithm,"
  }
]

About this extraction

This page contains the full source code of the WZDTHU/NiT GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 91 files (565.6 KB), approximately 145.9k tokens, and a symbol index with 746 extracted functions, classes, methods, constants, and types.