[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# UV\n#   Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#uv.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in 
version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# Ruff stuff:\n.ruff_cache/\n\n# PyPI configuration file\n.pypirc\n\n*.json\n*.png \n*.jpg\n/checkpoints\n/workdir\n/datasets\n/wandb\n/samples"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\"> Native-Resolution Image Synthesis</h1>\n\n<!-- \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/representation-alignment-for-generation/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=representation-alignment-for-generation) -->\n\n\n\n<div align=\"center\">\n  <a href=\"https://github.com/WZDTHU\" target=\"_blank\">ZiDong&nbsp;Wang</a><sup>1,2</sup> \n  &ensp; <b>&middot;</b> &ensp;\n  <a href=\"http://leibai.site\" target=\"_blank\">Lei&nbsp;Bai</a><sup>2,*</sup> \n  &ensp; <b>&middot;</b> &ensp;\n  <a href=\"https://xyue.io\" target=\"_blank\">Xiangyu&nbsp;Yue</a><sup>1</sup> \n  &ensp; <b>&middot;</b> &ensp;\n  <a href=\"https://wlouyang.github.io\" target=\"_blank\">Wanli&nbsp;Ouyang</a><sup>1,2</sup>\n  &ensp; <b>&middot;</b> &ensp;\n  <a href=\"https://invictus717.github.io\" target=\"_blank\">Yiyuan&nbsp;Zhang</a><sup>1,2,*</sup> </b>\n  \n  <sup>1</sup> MMLab CUHK &emsp; <sup>2</sup>Shanghai AI Lab <br>\n  <sup>*</sup>Correspondance &emsp; <br>\n</div>\n<h3 align=\"center\">\n[<a href=\"https://wzdthu.github.io/NiT\">project page</a>]&emsp;\n[<a href=\"https://arxiv.org/abs/2506.03131\">arXiv</a>]&emsp;\n[<a href=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K\">Dataset</a>]&emsp;\n[<a href=\"https://huggingface.co/GoodEnough/NiT-Models\">Model</a>]&emsp;\n\n</h3>\n<br>\n\n\n<b>Summary</b>: We propose Native-resolution diffusion Transformer (NiT), a model that explicitly learns varing resolutions and aspect ratios within its denoising process. This significantly improves training efficiency and generalization capability. To the best of our knowledge, <b>NiT firstly attains SOTA results on both</b> $256\\times256$ ($2.08$ <b>FID</b>) <b>and</b> $512\\times512$ ($1.48$ <b>FID</b>) <b>benchmarks in class-guided ImageNet generation</b>. 
NiT can also generalize to arbitrary resolutions and aspect ratios, such as $4.52$ FID on $1024\\times1024$ resolution, $4.11$ FID on $432\\times768$ resolution.\n\n\n![Figure](./assets/teaser.png)\n\n### 🚨 News\n\n\n- `2025-9-18` NiT is accepted by NeurIPS 2025! 🍺\n\n- `2025-6-3` We are delighted to introduce NiT, which is the first work to explicitly model native resolution image synthesis. We have released the code, pretrained models, and processed dataset of NiT.\n\n\n\n### 1. Setup\n\n\nFirst, clone the repo:\n```bash\ngit clone https://github.com/WZDTHU/NiT.git && cd NiT\n```\n\n#### 1.1 Environment Setup\n\n```bash\nconda create -n nit_env python=3.10\npip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118\npip install flash-attn\npip install -r requirements.txt\npip install -e .\n```\n\n\n#### 1.2 Model Zoo (WIP)\n\nWith a single model, NiT-XL can compete on multiple benchmarks and it achieves a dual SOTA on both ImageNet-$256\\times256$ and $512\\times512$ benchmarks.\n\n| Model | Model Zoo | Model Size | FID-256x256 | FID-512x512 | FID-768x768 | FID-1024x1024 |\n|---------------|------------|---------|------------|------------|------------|------------|\n| NiT-XL-1000K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors) | 675M | 2.16 | 1.57 | 4.05 | 4.52 |\n| NiT-XL-1500K | [🤗 HF](https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors) | 675M | 2.03 | 1.45 | - | - |\n\n\n```bash\nmkdir checkpoints\nwget -c \"https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1000K.safetensors\" -O checkpoints/nit_xl_model_1000K.safetensors\nwget -c \"https://huggingface.co/GoodEnough/NiT-XL-Models/resolve/main/model_1500K.safetensors\" -O checkpoints/nit_xl_model_1500K.safetensors\n```\n\n\n### 2. 
Sampling \n\n#### 2.1 Sampling Hyper-parameters\n\nThe sampling hyper-parameters for NiT-XL-1000K are summarized as follows:\n| Resolution | Solver | NFE | CFG - scale | CFG - interval | FID | sFID | IS | Prec. | Rec. |\n| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n| 256 × 256 | SDE | 250 | 2.25 | [0.0, 0.7] | 2.16 | 6.34 | 253.44 | 0.79 | 0.62 |\n| 512 × 512 | SDE | 250 |  2.05 | [0.0, 0.7] | 1.57 | 4.13 | 260.69 | 0.81 | 0.63 |\n| 768 × 768 | ODE | 50 | 3.0 | [0.0, 0.7] | 4.05 | 8.77 | 262.31 | 0.83 | 0.52 |\n| 1024 × 1024 | ODE | 50 |  3.0 | [0.0, 0.8] | 4.52 | 7.99 | 286.87 | 0.82 | 0.50 |\n| 1536 × 1536 | ODE | 50 |  3.5 | [0.0, 0.9] | 6.51 | 9.97 | 230.10 | 0.83 | 0.42 |\n| 2048 × 2048 | ODE | 50 |  4.5 | [0.0, 0.9] | 24.76 | 18.02 | 131.36 | 0.67 | 0.46 |\n| 320 × 960 | ODE | 50 |  4.0 | [0.0, 0.9] | 16.85 | 17.79 | 189.18 | 0.71 | 0.38 |\n| 432 × 768 | ODE | 50 |  2.75 | [0.0, 0.7] | 4.11 | 10.30 | 254.71 | 0.83 | 0.55 |\n| 480 × 640 | ODE | 50 |  2.75 | [0.0, 0.7] | 3.72 | 8.23 | 284.94 | 0.83 | 0.54 |\n| 640 × 480 | ODE | 50 |  2.5 | [0.0, 0.7] | 3.41 | 8.07 | 259.06 | 0.83 | 0.56 |\n| 768 × 432 | ODE | 50 |  2.85 | [0.0, 0.7] | 5.27 | 9.92 | 218.78 | 0.80 | 0.55 |\n| 960 × 320 | ODE | 50 |  4.5 | [0.0, 0.9] | 9.90 | 25.78 | 255.95 | 0.74 | 0.40 |\n\n#### 2.2 Sampling Scripts\n\nSampling with NiT-XL-1000K model for $256\\times256$-resolution images: \n```bash\nbash scripts/sample/sample_256x256.sh\n```\n\nSampling with NiT-XL-1000K model for $512\\times512$-resolution images: \n```bash\nbash scripts/sample/sample_512x512.sh\n```\n\nSampling with NiT-XL-1000K model for $768\\times768$-resolution images: \n```bash\nbash scripts/sample/sample_768x768.sh\n```\n\n### 3. Evaluation\n\nThe sampling generates a folder of samples to compute FID, Inception Score and\nother metrics. 
\n<b>Note that we do not pack the generated samples as a `.npz` file; this does not affect the calculation of FID and other metrics.</b>\nPlease follow the [ADM's TensorFlow\nevaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations)\nto setup the conda-environment and download the reference batch. \n\n```bash\nwget -c \"https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb\" -O checkpoints/classify_image_graph_def.pb\n```\n\n\nGiven the directory of the reference batch `REFERENCE_DIR` and the directory of the generated images `SAMPLING_DIR`, run the following command:\n```bash\npython projects/evaluate/adm_evaluator.py $REFERENCE_DIR $SAMPLING_DIR\n```\n\n\n\n### 4. Training\n\n#### 4.1 Dataset Setup\n\nCurrently, we provide all the [preprocessed dataset](https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K) for ImageNet1K. Please use the following commands to download the meta files and preprocessed latents.\n\n```bash\nmkdir datasets\nmkdir datasets/imagenet1k\n\nbash tools/download_dataset_256x256.sh\nbash tools/download_dataset_512x512.sh\nbash tools/download_dataset_native_resolution.sh\n```\n\n####  Preprocess ImageNet1K Locally\n\nYou can also preprocess the ImageNet1K dataset on your own. \nTake $256\\times256$-image preprocess as an example, you should first modify the `data_dir` as your local ImageNet1K directory in `configs/preprocess/imagenet1k_256x256.yaml`. \nThen run the preprocess script `scripts/preprocess/preorocess_in1k_256x256.sh`.\n```bash\nbash scripts/preprocess/preorocess_in1k_256x256.sh\n```\n\nThe preprocessing procedure of $512\\times512$-image and native-resolution image is similar. 
\nModify the corresponding config file and run the script.\n```bash\nbash scripts/preprocess/preorocess_in1k_512x512.sh\nbash scripts/preprocess/preorocess_in1k_native_resolution.sh\n```\n\n\n#### 4.2 Packing \n\nAs we pack multiple image instances with distinct resolution into one sequence, we need to pre-set the image indices of each pack before the training process. \n\n#### Download meta-info files\nFirst, download all the data-meta files, which store the height, width and other information of each image.\n```bash\nbash tools/download_dataset_data_meta.sh\n```\nThe above command will download the four data-meta files to the `datasets/imagenet1k/data_meta` directory:\n\n- `dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl`: data-meta file for $256\\times256$-resolution image data.\n- `dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl`, data-meta file for $512\\times512$-resolution image data.\n- `dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`, data-meta file for native-resolution image data.\n- `dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl`, a merged file of the above three files.\n\nThe first two items of the native-resolution-image data-meta file (`dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl`) are as follows:\n```json\n{\"image_file\": \"n01601694/n01601694_11629.JPEG\", \"latent_file\": \"n01601694/n01601694_11629.safetensors\", \"ori_w\": 580, \"ori_h\": 403, \"latent_h\": 12, \"latent_w\": 18, \"image_h\": 384, \"image_w\": 576, \"type\": \"native-resolution\"}\n\n{\"image_file\": \"n01601694/n01601694_11799.JPEG\", \"latent_file\": \"n01601694/n01601694_11799.safetensors\", \"ori_w\": 500, \"ori_h\": 350, \"latent_h\": 10, \"latent_w\": 15, \"image_h\": 320, \"image_w\": 480, \"type\": \"native-resolution\"}\n```\n\n#### Sampler-Meta Download\n\nGiven the maximum length $L$, we pre-set the image indices of each pack before training. 
\nHere we use the LPFHP (longest-pack-first histogram packing) algorithm to pack all the dataset.\n\nYou can download our preprocessed packed sampler-meta file using the following command.\n```bash\nbash tools/download_dataset_sampler_meta.sh\n```\nThe above command will download the four sampler-meta files to the `datasets/imagenet1k/sampler_meta` directory:\n- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json`: corresponds to $L=8192$.\n- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json`: corresponds to $L=16384$. This is the setting in NiT-XL experiments.\n- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json`, corresponds to $L=32768$.\n- `dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json`, corresponds to $L=65536$.\n\n\n#### Prepare the Packing (Sampler-Meta) on Your Own\n\nNiT supports training with images of arbitrary resolutions and aspect ratios, you can also prepare the packing (sampler-meta) according to your own demands.\n\n```bash\n# generate the default sampler-meta\npython tools/pack_dataset.py\n# generate the sampler-meta for fixed 256x256-resolution experiment with the maximum sequence length of 16384\npython tools/pack_dataset.py --data-meta datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl --max-seq-len 16384\n```\n\n\n\n#### Download Image Encoder\n\nFor the NiT-S (33M) model, we use RADIO-v2.5-B as the image encoder for the REPA-loss.\nFor other NiT models, we use RADIO-v2.5-H as our image encoder.\n\n```bash\nwget -c \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar\" -O checkpoints/radio_v2.5-h.pth.tar\n\nwget -c \"https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar\" -O checkpoints/radio-v2.5-b_half.pth.tar\n```\n\n\n####  Training Scripts\nThe above steps setup the `packed_json`, `jsonl_dir`, and `latent_dirs` in `configs/c2i/nit_xl_pack_merge_radio_16384.yaml`. 
\nBefore training, please specify the `image_dir` as the directory of ImageNet1K dataset in your own machine. \nTo train the XL-model (675M): \n```bash\nbash scripts/train/train_xl_model.sh\n```\n\nSpecify the `image_dir` in `configs/c2i/nit_s_pack_merge_radio_65536.yaml` and train the small-model (33M):\n```bash\nbash scripts/train/train_s_model.sh\n```\nSpecify the `image_dir` in `configs/c2i/nit_b_pack_merge_radio_65536.yaml` and train the base-model (131M):\n```bash\nbash scripts/train/train_b_model.sh\n```\nSpecify the `image_dir` in `configs/c2i/nit_l_pack_merge_radio_16384.yaml` and train the large-model (457M):\n```bash\nbash scripts/train/train_l_model.sh\n```\nSpecify the `image_dir` in `configs/c2i/nit_xxl_pack_merge_radio_8192.yaml` and train the xxl-model (1.37B):\n```bash\nbash scripts/train/train_xxl_model.sh\n```\n\n\n\n\n### Citations\nIf you find the project useful, please kindly cite: \n```bibtex\n@article{wang2025native,\n  title={Native-Resolution Image Synthesis}, \n  author={Wang, Zidong and Bai, Lei and Yue, Xiangyu and Ouyang, Wanli and Zhang, Yiyuan},\n  year={2025},\n  eprint={2506.03131},\n  archivePrefix={arXiv},\n  primaryClass={cs.CV}\n}\n```\n\n### License\nThis project is licensed under the Apache-2.0 license.\n"
  },
  {
    "path": "configs/c2i/nit_b_pack_merge_radio_65536.yaml",
    "content": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models.c2i.nit_model.NiT\n    params:\n      class_dropout_prob: 0.1\n      num_classes: 1000\n      depth: 12\n      hidden_size: 768\n      patch_size: 1\n      in_channels: 32\n      num_heads: 12\n      qk_norm: True\n      encoder_depth: 4\n      z_dim: 1280\n      use_checkpoint: False\n  # pretrained_vae:\n  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n  slice_vae: False\n  tile_vae: False\n  # repa encoder\n  enc_type: radio\n  enc_dir: checkpoints/radio_v2.5-h.pth.tar\n  proj_coeff: 1.0\n  # ema\n  use_ema: True\n  ema_decay: 0.9999\n  \ndata:\n  data_type: improved_pack\n  dataset:\n    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json\n    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\n    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']\n    latent_dirs: [\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',\n    ]\n    image_dir: <Your imagenet1k directory>/train\n  dataloader:\n    num_workers: 4\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 'c2i'}}\n  max_train_steps: 2000000\n  checkpointing_steps: 2000\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 5.0e-5\n  learning_rate_base_batch_size: 1\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 0\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n     
 betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [200000, 500000, 100000, 150000]\n"
  },
  {
    "path": "configs/c2i/nit_l_pack_merge_radio_16384.yaml",
    "content": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models.c2i.nit_model.NiT\n    params:\n      class_dropout_prob: 0.1\n      num_classes: 1000\n      depth: 24\n      hidden_size: 1024\n      patch_size: 1\n      in_channels: 32\n      num_heads: 16\n      qk_norm: True\n      encoder_depth: 6\n      z_dim: 1280\n      use_checkpoint: False\n  # pretrained_vae:\n  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n  slice_vae: False\n  tile_vae: False\n  # repa encoder\n  enc_type: radio\n  enc_dir: checkpoints/radio_v2.5-h.pth.tar\n  proj_coeff: 1.0\n  # ema\n  use_ema: True\n  ema_decay: 0.9999\n  \ndata:\n  data_type: improved_pack\n  dataset:\n    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json\n    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\n    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']\n    latent_dirs: [\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',\n    ]\n    image_dir: <Your imagenet1k directory>/train\n  dataloader:\n    num_workers: 4\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 'c2i'}}\n  max_train_steps: 2000000\n  checkpointing_steps: 2000\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 5.0e-5\n  learning_rate_base_batch_size: 4\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 0\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n    
  betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [200000, 500000, 100000, 150000]\n"
  },
  {
    "path": "configs/c2i/nit_s_pack_merge_radio_65536.yaml",
    "content": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models.c2i.nit_model.NiT\n    params:\n      class_dropout_prob: 0.1\n      num_classes: 1000\n      depth: 12\n      hidden_size: 384\n      patch_size: 1\n      in_channels: 32\n      num_heads: 6\n      qk_norm: True\n      encoder_depth: 4\n      z_dim: 768\n      projector_dim: 768\n      use_checkpoint: False\n  # pretrained_vae:\n  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n  slice_vae: False\n  tile_vae: False\n  # repa encoder\n  enc_type: radio\n  enc_dir: checkpoints/radio-v2.5-b_half.pth.tar\n  proj_coeff: 1.0\n  # ema\n  use_ema: True\n  ema_decay: 0.9999\n  \ndata:\n  data_type: improved_pack\n  dataset:\n    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json\n    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\n    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']\n    latent_dirs: [\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',\n    ]\n    image_dir: <Your imagenet1k directory>/train\n  dataloader:\n    num_workers: 4\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 'c2i'}}\n  max_train_steps: 2000000\n  checkpointing_steps: 2000\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 5.0e-5\n  learning_rate_base_batch_size: 1\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 0\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # 
betas: ${tuple:0.9, 0.999}\n      betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [200000, 500000, 100000, 150000]\n"
  },
  {
    "path": "configs/c2i/nit_xl_pack_merge_radio_16384.yaml",
    "content": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models.c2i.nit_model.NiT\n    params:\n      class_dropout_prob: 0.1\n      num_classes: 1000\n      depth: 28\n      hidden_size: 1152\n      patch_size: 1\n      in_channels: 32\n      num_heads: 16\n      qk_norm: True\n      encoder_depth: 8\n      z_dim: 1280\n      use_checkpoint: False\n  # pretrained_vae:\n  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n  slice_vae: False\n  tile_vae: False\n  # repa encoder\n  enc_type: radio\n  enc_dir: checkpoints/radio_v2.5-h.pth.tar\n  proj_coeff: 1.0\n  # ema\n  use_ema: True\n  ema_decay: 0.9999\n  \ndata:\n  data_type: improved_pack\n  dataset:\n    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json\n    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\n    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']\n    latent_dirs: [\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',\n    ]\n    image_dir: <Your imagenet1k directory>/train\n  dataloader:\n    num_workers: 4\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 'c2i'}}\n  max_train_steps: 2000000\n  checkpointing_steps: 2000\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 5.0e-5\n  learning_rate_base_batch_size: 4\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 0\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n    
  betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [200000, 500000, 100000, 150000]\n"
  },
  {
    "path": "configs/c2i/nit_xxl_pack_merge_radio_8192.yaml",
    "content": "model: \n  transport:\n    path_type: linear\n    prediction: v\n    weighting: lognormal\n  network:\n    target: nit.models.c2i.nit_model.NiT\n    params:\n      class_dropout_prob: 0.1\n      num_classes: 1000\n      depth: 40\n      hidden_size: 1536\n      patch_size: 1\n      in_channels: 32\n      num_heads: 24\n      qk_norm: True\n      encoder_depth: 8\n      z_dim: 1280\n      use_checkpoint: False\n      use_adaln_lora: True\n      adaln_lora_dim: 512\n  # pretrained_vae:\n  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n  slice_vae: False\n  tile_vae: False\n  # repa encoder\n  enc_type: radio\n  enc_dir: checkpoints/radio_v2.5-h.pth.tar\n  proj_coeff: 1.0\n  # ema\n  use_ema: True\n  ema_decay: 0.9999\n  \ndata:\n  data_type: improved_pack\n  dataset:\n    packed_json: datasets/imagenet1k/sampler_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json\n    jsonl_dir: datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\n    data_types: ['native-resolution', 'fixed-256x256', 'fixed-512x512']\n    latent_dirs: [\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256',\n      'datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512',\n    ]\n    image_dir: <Your imagenet1k directory>/train\n  dataloader:\n    num_workers: 4\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 'c2i'}}\n  max_train_steps: 1000000\n  checkpointing_steps: 2000\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 5.0e-5\n  learning_rate_base_batch_size: 4\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 0\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: 
torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n      betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [200000, 500000, 100000]\n"
  },
  {
    "path": "configs/preprocess/imagenet1k_256x256.yaml",
    "content": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/train\n    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256\n    resolution: 256\n  dataloader:\n    num_workers: 1\n    batch_size: 64  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 't2i'}}\n  max_train_steps: 100000\n  checkpointing_steps: 200\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 1.0e-4\n  learning_rate_base_batch_size: 256\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 4000\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n      betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [20000, 40000, 60000, 80000]\n"
  },
  {
    "path": "configs/preprocess/imagenet1k_512x512.yaml",
    "content": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/train\n    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512\n    resolution: 512\n  dataloader:\n    num_workers: 1\n    batch_size: 16  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 't2i'}}\n  max_train_steps: 100000\n  checkpointing_steps: 200\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 1.0e-4\n  learning_rate_base_batch_size: 256\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 4000\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n      betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [20000, 40000, 60000, 80000]\n"
  },
  {
    "path": "configs/preprocess/imagenet1k_native_resolution.yaml",
    "content": "model:\n  vae: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers\n\ndata:\n  dataset:\n    data_dir: <Your imagenet1k directory>/train\n    target_dir: ./datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution\n    min_image_size: 32\n    max_image_size: 2048\n  dataloader:\n    num_workers: 1\n    batch_size: 1  # Batch size (per device) for the training dataloader.\n\n  \n  \ntraining:\n  tracker: null\n  tracker_kwargs: {'wandb': {'group': 't2i'}}\n  max_train_steps: 100000\n  checkpointing_steps: 200\n  checkpoints_total_limit: 2\n  resume_from_checkpoint: latest\n  learning_rate: 1.0e-4\n  learning_rate_base_batch_size: 256\n  scale_lr: True\n  lr_scheduler: constant # \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"]\n  lr_warmup_steps: 4000\n  gradient_accumulation_steps: 1\n  optimizer: \n    target: torch.optim.AdamW\n    params:\n      # betas: ${tuple:0.9, 0.999}\n      betas: [0.9, 0.95]\n      weight_decay: 1.0e-2\n      eps: 1.0e-6\n  max_grad_norm: 1.0\n  proportion_empty_prompts: 0.0\n  mixed_precision: bf16 # [\"no\", \"fp16\", \"bf16\"]\n  allow_tf32: True \n  validation_steps: 500\n  checkpoint_list: [20000, 40000, 60000, 80000]\n"
  },
  {
    "path": "nit/data/pack/__init__.py",
    "content": "from .ennlshp import ENNLSHP\nfrom .lpfhp import LPFHP\nfrom .nnlshp import NNLSHP\nfrom .spfhp import SPFHP\n\nimport json\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\ndef get_strategy(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens):\n    def generate_histogram(dataset_seq_lens):\n        histogram = np.zeros(max_seq_len, dtype=np.int64)\n        seq_lens, counts = np.unique(np.array(dataset_seq_lens), return_counts=True)\n        histogram[seq_lens - 1] = counts\n        return histogram\n    histogram = generate_histogram(dataset_seq_lens)\n    if algorithm == \"SPFHP\":\n        strategy = SPFHP(histogram, max_seq_len, max_seq_per_pack)\n    elif algorithm == \"LPFHP\":\n        strategy = LPFHP(histogram, max_seq_len, max_seq_per_pack)\n    elif algorithm == 'ENNLSHP':\n        strategy = ENNLSHP(histogram, max_seq_len, max_seq_per_pack)\n    elif algorithm == 'NNLSHP':\n        strategy = NNLSHP(histogram, max_seq_len, max_seq_per_pack)\n    else:\n        raise NotImplementedError(\"Algorithm type unsupported. 
Pass one of: LPFHP, SPFHP\")\n    return strategy\n\ndef pack_dataset(algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens, dataset_seq_idxs):\n    dataset_seqs = torch.stack([torch.tensor(dataset_seq_lens), torch.tensor(dataset_seq_idxs)])\n    strategy_set, strategy_repeat_count = get_strategy(\n        algorithm, max_seq_len, max_seq_per_pack, dataset_seq_lens\n    )\n    \n    packed_indices = []\n    run_iters = sum(strategy_repeat_count)\n    progress_bar = tqdm(range(run_iters))\n    for i in range(len(strategy_repeat_count)):\n        strategy = strategy_set[i]\n        for _ in range(strategy_repeat_count[i]):\n            progress_bar.update(1)\n            ref_inds = []\n            for x in strategy:\n                ref_ind = torch.argwhere(dataset_seqs[0] == x)[-1]\n                dataset_seqs[0, ref_ind] = -1\n                ref_inds.append(ref_ind)\n            inds = dataset_seqs[1, ref_inds].ravel()\n            packed_indices.append(inds.tolist())\n    return packed_indices\n\n"
  },
  {
    "path": "nit/data/pack/ennlshp.py",
    "content": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/ennlshp.py\n\"\"\"Extended Non-Negative least squares histogram-packing.\"\"\"\nimport time\nimport numpy as np\nfrom scipy import optimize, stats\nfrom functools import lru_cache\n\n\ndef get_packing_matrix(strategy_set, max_sequence_length):\n    num_strategies = len(strategy_set)\n    A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32)\n    for i, strategy in enumerate(strategy_set):\n        for seq_len in strategy:\n            A[seq_len - 1, i] += 1\n    return A\n\n\n@lru_cache(maxsize=None)\ndef get_packing_strategies(start_length, minimum_increment, target_length, depth):\n    gap = target_length - start_length\n    strategies = []\n    # Complete the packing with exactly 1 number\n    if depth == 1:\n        if gap >= minimum_increment:\n            strategies.append([gap])\n    # Complete the sample in \"depth\" steps, recursively\n    else:\n        for new in range(minimum_increment, gap + 1):\n            new_gap = target_length - start_length - new\n            if new_gap == 0:\n                strategies.append([new])\n            else:\n                options = get_packing_strategies(start_length + new, new, target_length, depth - 1)\n                for option in options:\n                    if len(option) > 0:\n                        strategies.append([new] + option)\n    return strategies\n\n\ndef ENNLSHP(histogram, max_sequence_length, max_sequences_per_pack):\n    # List all unique ways of packing to the desired maximum sequence length\n    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)\n    # Get the packing matrix corresponding to this list of packing strategies\n    A = get_packing_matrix(strategy_set, max_sequence_length)\n    # Weights that penalize the residual by the number of resulting padding tokens.\n    
w0 = np.array([x + 1 for x in range(max_sequence_length)])\n    # construct the packing matrix\n    A_bar = np.zeros((2 * max_sequence_length, len(strategy_set) + max_sequence_length), \"d\")\n    # Base weighted matrix\n    A_bar[:max_sequence_length, : len(strategy_set)] = np.expand_dims(w0, -1) * A\n    # Higher weight to avoid positive residual\n    A_bar[max_sequence_length:, : len(strategy_set)] = np.expand_dims(10**6 * np.ones([max_sequence_length]), -1) * A\n    # negative diagonal unity matrix for mapping to residual\n    A_bar[max_sequence_length:, len(strategy_set) :] = np.expand_dims(\n        10**6 * np.ones([max_sequence_length]), -1\n    ) * np.ones((max_sequence_length, max_sequence_length))\n    b_bar = np.zeros(2 * max_sequence_length)\n    # Apply weighting to histogram vector\n    b_bar[:max_sequence_length] = w0 * histogram\n    b_bar[max_sequence_length:] = 10**6 * np.ones([max_sequence_length]) * histogram\n    # Solve the packing problem\n    start = time.time()\n    strategy_residual, rnorm = optimize.nnls(A_bar, b_bar)\n    strategy_repeat_count = strategy_residual[: len(strategy_set)]\n    # Round the floating point solution to nearest integer\n    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)\n    # Compute the residuals, shape: [max_sequence_length]\n    residual = histogram - A @ strategy_repeat_count\n    # Handle the left-over sequences; that is the positive part of residual\n    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]\n    for l in unpacked_seqlen:\n        strategy = sorted([l, max_sequence_length - l])  # the depth 1 strategy\n        strategy_index = strategy_set.index(strategy)\n        strategy_repeat_count[strategy_index] += residual[l - 1]\n    # Re-compute the residual with the updated strategy_repeat_count\n    # This should now be strictly < 0\n    residual = histogram - A @ strategy_repeat_count\n    # Add padding based on deficit (negative residual portion of 
residual)\n    padding = np.where(residual < 0, -residual, 0)\n\n    # Calculate some basic statistics\n    duration = time.time() - start\n    sequence_lengths = np.arange(1, max_sequence_length + 1)\n    old_number_of_samples = histogram.sum()\n    new_number_of_samples = int(strategy_repeat_count.sum())\n    speedup_upper_bound = 1.0 / (\n        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples\n    )\n    num_padding_tokens_packed = (sequence_lengths * padding).sum()\n    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)\n    print(\n        f\"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\\n\",\n        f\"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\\n\",\n        f\"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\\n\"\n        f\"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.\",\n    )\n    return strategy_set, strategy_repeat_count\n"
  },
  {
    "path": "nit/data/pack/lpfhp.py",
    "content": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/lpfhp.py\n\"\"\"Longest-pack-first histogram-packing.\"\"\"\nfrom collections import defaultdict\nimport numpy as np\nimport time\n\n\ndef add_pack(pack, count, tmp, final, limit, offset, max_sequence_length=512):\n    \"\"\"Filter out packs that reached maximum length or number of components.\"\"\"\n    # sanity checks\n    assert max_sequence_length - sum(pack) == offset, \"Incorrect offset.\"\n    assert offset >= 0, \"Too small offset.\"\n    assert offset < max_sequence_length, \"Too large offset.\"\n    if len(pack) == limit or offset == 0:\n        final[offset].append((count, pack))\n    else:\n        tmp[offset].append((count, pack))\n\n\ndef LPFHP(histogram, max_sequence_length, max_sequences_per_pack, distribute=True):\n    \"\"\"Longest-pack-first histogram-packing.\"\"\"\n    start = time.time()\n    reversed_histogram = np.flip(histogram)\n    # Initialize main strategy data dictionary.\n    # The key indicates how many tokens are left for full length.\n    # The value is a list of tuples, consisting of counts and respective packs.\n    # A pack is a (sorted) list of sequence length values that get concatenated.\n    tmp_strategies_per_length = defaultdict(list)\n    strategies_per_length = defaultdict(list)\n    if max_sequences_per_pack == \"max\":\n        max_sequences_per_pack = max_sequence_length\n    # Index i indicates here, how much space is left, due to reversed histogram\n    for i in range(max_sequence_length):\n        n_sequences_to_bin = reversed_histogram[i]\n        length_to_bin = max_sequence_length - i\n        offset = 0  # smallest possible offset for perfect fit\n        while n_sequences_to_bin > 0:\n            if (length_to_bin + offset) in tmp_strategies_per_length:\n                # extract worst pack that will get modified\n                
n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()\n                # calculate how often the current sequence maximally fits in\n                repeat = min(1 + offset // length_to_bin, max_sequences_per_pack - len(pack))\n                # correct dependent on count\n                while n_sequences_to_bin // repeat == 0:\n                    repeat -= 1\n                if not distribute:\n                    repeat = 1\n                new_pack = pack + [length_to_bin] * repeat\n                count = min(n_sequences_to_pack, n_sequences_to_bin // repeat)\n                if n_sequences_to_pack > count:\n                    # old pack gets reduced\n                    n_sequences_to_pack -= count\n                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))\n                    n_sequences_to_bin -= count * repeat\n                else:\n                    n_sequences_to_bin -= n_sequences_to_pack * repeat\n                add_pack(\n                    new_pack,\n                    count,\n                    tmp_strategies_per_length,\n                    strategies_per_length,\n                    max_sequences_per_pack,\n                    offset - (repeat - 1) * length_to_bin,\n                    max_sequence_length,\n                )\n                # clean up to speed up main key search\n                if not tmp_strategies_per_length[length_to_bin + offset]:\n                    tmp_strategies_per_length.pop(length_to_bin + offset)\n                # reset offset in case best fit changed\n                offset = 0\n            else:\n                offset += 1\n            # Does not fit anywhere. 
Create new pack.\n            if offset >= max_sequence_length - length_to_bin + 1:\n                # similar repetition but no dependence on pack.\n                repeat = min(max_sequence_length // length_to_bin, max_sequences_per_pack)\n                while n_sequences_to_bin // repeat == 0:\n                    repeat -= 1\n                if not distribute:\n                    repeat = 1\n                add_pack(\n                    [length_to_bin] * repeat,\n                    n_sequences_to_bin // repeat,\n                    tmp_strategies_per_length,\n                    strategies_per_length,\n                    max_sequences_per_pack,\n                    max_sequence_length - length_to_bin * repeat,\n                    max_sequence_length,\n                )\n                n_sequences_to_bin -= n_sequences_to_bin // repeat * repeat\n    # merge all strategies\n    for key in tmp_strategies_per_length:\n        strategies_per_length[key].extend(tmp_strategies_per_length[key])\n    # flatten strategies dictionary\n    strategy_set = []\n    strategy_repeat_count = []\n    for key in strategies_per_length:\n        for count, pack in strategies_per_length[key]:\n            pack.reverse()\n            strategy_set.append(pack)\n            strategy_repeat_count.append(count)\n\n    # Summarize efficiency of solution\n    duration = time.time() - start\n    sequence_lengths = np.arange(1, max_sequence_length + 1)\n    strategy_repeat_count = np.array(strategy_repeat_count)\n    n_strategies = len(strategy_set)\n    old_number_of_samples = histogram.sum()\n    new_number_of_samples = strategy_repeat_count.sum()\n    sequences = sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])\n    total_tokens = max_sequence_length * new_number_of_samples\n    empty_tokens = sum(\n        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]\n    )\n    efficiency = 100 - 
empty_tokens / total_tokens * 100\n    speedup_upper_bound = 1.0 / (\n        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples\n    )\n\n    print(\n        f\"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\\n\",\n        f\"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\\n\",\n        f\"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\",\n        f\"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.\",\n    )\n\n    return strategy_set, strategy_repeat_count\n"
  },
  {
    "path": "nit/data/pack/nnlshp.py",
    "content": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/nnlshp.py\n\"\"\"Non-Negative least squares histogram-packing.\"\"\"\nimport time\nimport numpy as np\nfrom scipy import optimize, stats\nfrom functools import lru_cache\n\n\ndef get_packing_matrix(strategy_set, max_sequence_length):\n    num_strategies = len(strategy_set)\n    A = np.zeros((max_sequence_length, num_strategies), dtype=np.int32)\n    for i, strategy in enumerate(strategy_set):\n        for seq_len in strategy:\n            A[seq_len - 1, i] += 1\n    return A\n\n\n@lru_cache(maxsize=None)\ndef get_packing_strategies(start_length, minimum_increment, target_length, depth):\n    gap = target_length - start_length\n    strategies = []\n    # Complete the packing with exactly 1 number\n    if depth == 1:\n        if gap >= minimum_increment:\n            strategies.append([gap])\n    # Complete the sample in \"depth\" steps, recursively\n    else:\n        for new in range(minimum_increment, gap + 1):\n            new_gap = target_length - start_length - new\n            if new_gap == 0:\n                strategies.append([new])\n            else:\n                options = get_packing_strategies(start_length + new, new, target_length, depth - 1)\n                for option in options:\n                    if len(option) > 0:\n                        strategies.append([new] + option)\n    return strategies\n\n\ndef NNLSHP(histogram, max_sequence_length, max_sequences_per_pack):\n    # List all unique ways of packing to the desired maximum sequence length\n    strategy_set = get_packing_strategies(0, 1, max_sequence_length, max_sequences_per_pack)\n    # Get the packing matrix corresponding to this list of packing strategies\n    A = get_packing_matrix(strategy_set, max_sequence_length)\n    # Weights that penalize the residual on short sequences less.\n    penalization_cutoff = 8\n    
w0 = np.ones([max_sequence_length])\n    w0[:penalization_cutoff] = 0.09\n    # Solve the packing problem\n    start = time.time()\n    strategy_repeat_count, rnorm = optimize.nnls(np.expand_dims(w0, -1) * A, w0 * histogram)\n    # Round the floating point solution to nearest integer\n    strategy_repeat_count = np.rint(strategy_repeat_count).astype(np.int64)\n    # Compute the residuals, shape: [max_sequence_length]\n    residual = histogram - A @ strategy_repeat_count\n    # Handle the left-over sequences, that is the positive part of residual\n    unpacked_seqlen = np.arange(1, max_sequence_length + 1)[residual > 0]\n    for l in unpacked_seqlen:\n        strategy = sorted([l, max_sequence_length - l])  # the depth 1 strategy\n        strategy_index = strategy_set.index(strategy)\n        strategy_repeat_count[strategy_index] += residual[l - 1]\n    # Re-compute the residual with the updated strategy_repeat_count\n    # This should now be strictly < 0\n    residual = histogram - A @ strategy_repeat_count\n    # Add padding based on deficit (negative residual portion of residual)\n    padding = np.where(residual < 0, -residual, 0)\n\n    # Calculate some basic statistics\n    duration = time.time() - start\n    sequence_lengths = np.arange(1, max_sequence_length + 1)\n    old_number_of_samples = histogram.sum()\n    new_number_of_samples = int(strategy_repeat_count.sum())\n    speedup_upper_bound = 1.0 / (\n        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples\n    )\n    num_padding_tokens_packed = (sequence_lengths * padding).sum()\n    efficiency = 1 - num_padding_tokens_packed / (new_number_of_samples * max_sequence_length)\n    print(\n        f\"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\\n\",\n        f\"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\\n\",\n        f\"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\\n\"\n        
f\"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.\",\n    )\n\n    return strategy_set, strategy_repeat_count\n"
  },
  {
    "path": "nit/data/pack/spfhp.py",
    "content": "# Copyright (c) 2021 Graphcore Ltd. All rights reserved.\n# modified from https://github.com/graphcore/examples/blob/v3.2.0/tutorials/blogs_code/packedBERT/spfhp.py\n\"\"\"Shortest-pack-first histogram-packing.\"\"\"\nfrom collections import defaultdict\nimport numpy as np\nimport time\n\n\ndef add_pack(pack, count, tmp, final, limit, offset):\n    \"\"\"Filter out packs that reached maximum length or number of sequences.\"\"\"\n    if len(pack) == limit or offset == 0:\n        final[offset].append((count, pack))\n    else:\n        tmp[offset].append((count, pack))\n\n\ndef SPFHP(histogram, max_sequence_length, max_sequences_per_pack):\n    \"\"\"Shortest-pack-first histogram-packing.\"\"\"\n    start = time.time()\n    reversed_histogram = np.flip(histogram)\n    # Initialize main strategy data dictionary.\n    # The key indicates how many tokens are left for full length.\n    # The value is a list of tuples, consisting of counts and respective packs.\n    # A pack is a (sorted) list of sequence length values that get concatenated.\n    tmp_strategies_per_length = defaultdict(list)\n    strategies_per_length = defaultdict(list)\n    # Index i indicates here, how much space is left, due to reversed histogram\n    for i in range(max_sequence_length):\n        n_sequences_to_bin = reversed_histogram[i]\n        length_to_bin = max_sequence_length - i\n        offset = i + 1  # largest possible offset\n        while n_sequences_to_bin > 0:\n            if (length_to_bin + offset) in tmp_strategies_per_length:\n                # extract shortest pack that will get modified\n                n_sequences_to_pack, pack = tmp_strategies_per_length[length_to_bin + offset].pop()\n                new_pack = pack + [length_to_bin]\n                count = min(n_sequences_to_pack, n_sequences_to_bin)\n                if n_sequences_to_pack > n_sequences_to_bin:\n                    # old pack gets reduced\n                    n_sequences_to_pack -= 
n_sequences_to_bin\n                    tmp_strategies_per_length[length_to_bin + offset].append((n_sequences_to_pack, pack))\n                    n_sequences_to_bin = 0\n                else:\n                    n_sequences_to_bin -= n_sequences_to_pack\n                add_pack(\n                    new_pack,\n                    count,\n                    tmp_strategies_per_length,\n                    strategies_per_length,\n                    max_sequences_per_pack,\n                    offset,\n                )\n                # clean up to speed up main key search\n                if not tmp_strategies_per_length[length_to_bin + offset]:\n                    tmp_strategies_per_length.pop(length_to_bin + offset)\n            else:\n                offset -= 1\n            # Does not fit anywhere. Create new pack.\n            if offset < 0:\n                add_pack(\n                    [length_to_bin],\n                    n_sequences_to_bin,\n                    tmp_strategies_per_length,\n                    strategies_per_length,\n                    max_sequences_per_pack,\n                    i,\n                )\n                n_sequences_to_bin = 0\n    # merge all strategies\n    for key in tmp_strategies_per_length:\n        strategies_per_length[key].extend(tmp_strategies_per_length[key])\n    # flatten strategies dictionary\n    strategy_set = []\n    strategy_repeat_count = []\n    for key in strategies_per_length:\n        for count, pack in strategies_per_length[key]:\n            pack.reverse()\n            strategy_set.append(pack)\n            strategy_repeat_count.append(count)\n\n    # Summarize efficiency of solution\n    duration = time.time() - start\n    sequence_lengths = np.arange(1, max_sequence_length + 1)\n    strategy_repeat_count = np.array(strategy_repeat_count)\n    n_strategies = len(strategy_set)\n    old_number_of_samples = histogram.sum()\n    new_number_of_samples = strategy_repeat_count.sum()\n    sequences = 
sum([count * len(pack) for count, pack in zip(strategy_repeat_count, strategy_set)])\n    total_tokens = max_sequence_length * new_number_of_samples\n    empty_tokens = sum(\n        [count * (max_sequence_length - sum(pack)) for count, pack in zip(strategy_repeat_count, strategy_set)]\n    )\n    efficiency = 100 - empty_tokens / total_tokens * 100\n    speedup_upper_bound = 1.0 / (\n        1 - (histogram * (1 - sequence_lengths / max_sequence_length)).sum() / old_number_of_samples\n    )\n\n    print(\n        f\"Packing efficiency (fraction of real tokens): {efficiency:3.4f}\\n\",\n        f\"Speed-up theoretical limit: {speedup_upper_bound:3.4f}\\n\",\n        f\"Achieved speed-up over un-packed dataset: {old_number_of_samples/new_number_of_samples:3.5f}\\n\",\n        f\"Runtime: Packed {old_number_of_samples} sequences in {duration:3.3f} seconds.\",\n    )\n\n    return strategy_set, np.array(strategy_repeat_count)\n"
  },
  {
    "path": "nit/data/packed_c2i_data.py",
    "content": "import os\nimport datetime\nimport torchvision\nimport numpy as np\nimport torch\nimport ast\nimport json\nimport time\n\n\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.datasets import ImageFolder\nfrom torchvision import transforms\nfrom accelerate.logging import get_logger\nfrom safetensors.torch import load_file\nfrom einops import rearrange\nfrom functools import partial\nfrom torchvision.transforms.functional import hflip\n\nfrom .sampler_util import get_train_sampler, get_packed_batch_sampler\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\nPATCH_SIZE = 1\n\ndef resize_arr(pil_image, height, width):\n    pil_image = pil_image.resize((width, height), resample=Image.Resampling.BICUBIC)\n\n    return pil_image\n\ndef center_crop_arr(pil_image, image_size):\n    \"\"\"\n    Center cropping implementation from ADM.\n    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126\n    \"\"\"\n    while min(*pil_image.size) >= 2 * image_size:\n        pil_image = pil_image.resize(\n            tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX\n        )\n\n    scale = image_size / min(*pil_image.size)\n    pil_image = pil_image.resize(\n        tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC\n    )\n\n    arr = np.array(pil_image)\n    crop_y = (arr.shape[0] - image_size) // 2\n    crop_x = (arr.shape[1] - image_size) // 2\n    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])\n\ndef packed_collate_fn(batch):\n    packed_latent = []\n    label = []\n    hw_list = []\n    image = []\n    for data in batch:\n        C, H, W = data['latent'].shape\n        latent = rearrange(\n            data['latent'], 'c (h p1) (w p2) -> (h w) c p1 p2', p1=PATCH_SIZE, p2=PATCH_SIZE\n        )\n        
packed_latent.append(latent)\n        hw_list.append([H/PATCH_SIZE, W/PATCH_SIZE])\n        label.append(data['label'])\n        image.append(data['image'])\n    packed_latent = torch.concat(packed_latent)\n    label = torch.tensor(label)\n    hw_list = torch.tensor(hw_list, dtype=torch.int32)\n    return dict(image=image, latent=packed_latent, label=label, hw_list=hw_list)\n\n\n\nclass ImprovedPackedImageNetLatentDataset(Dataset):\n    def __init__(self, packed_json, jsonl_dir, data_types, latent_dirs, image_dir):\n        super().__init__()\n        assert len(data_types) == len(latent_dirs)\n        self.type_to_dir = dict()\n        for i, data_type in enumerate(data_types):\n            self.type_to_dir[data_type] = latent_dirs[i]\n        self.image_dir = image_dir\n\n        with open(packed_json, 'r') as fp:\n            self.packed_dataset = json.load(fp)\n\n        with open(jsonl_dir, 'r') as fp:\n            self.dataset = [json.loads(line) for line in fp]\n        \n    \n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        data_meta = self.dataset[index]\n        \n        data_item = dict()\n        data_type = data_meta['type']\n        latent_file = os.path.join(self.type_to_dir[data_type], data_meta['latent_file'])\n        image_file = os.path.join(self.image_dir, data_meta['image_file'])\n\n        data = load_file(latent_file)\n        \n        height = data_meta['latent_h'] * 16\n        width = data_meta['latent_w'] * 16\n        \n        if data_type == 'native-resolution':\n            preprocess = partial(resize_arr, height=height, width=width)\n        else:\n            assert height == width\n            preprocess = partial(center_crop_arr, image_size=height)\n\n        transform = transforms.Compose([\n            transforms.Lambda(lambda pil_image: preprocess(pil_image=pil_image)),\n            transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))),\n            
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),\n        ])\n\n        rand_idx = torch.randint(low=0, high=2, size=(1,)).item()\n        data_item['image'] = transform(Image.open(image_file).convert(\"RGB\"))[rand_idx]\n        data_item['latent'] = data['latent'][rand_idx]\n        data_item['label'] = data['label']\n        return data_item\n\n\nclass C2ILoader():\n    def __init__(self, data_config):\n        super().__init__()\n\n        self.batch_size = data_config.dataloader.batch_size\n        self.num_workers = data_config.dataloader.num_workers\n\n        self.data_type = data_config.data_type\n        \n    \n        if data_config.data_type == 'improved_pack':\n            self.train_dataset = ImprovedPackedImageNetLatentDataset(\n                **OmegaConf.to_container(data_config.dataset)\n            )\n        else:\n            raise NotImplementedError\n        \n        \n        self.test_dataset = None\n        self.val_dataset = None\n\n    def train_len(self):\n        return len(self.train_dataset)\n\n    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):\n        sampler = get_train_sampler(\n            self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed\n        )\n        if self.data_type == 'improved_pack':\n            batch_sampler = get_packed_batch_sampler(\n                self.train_dataset.packed_dataset, rank, world_size, max_steps, resume_steps, seed\n            )\n            return DataLoader(\n                self.train_dataset,\n                batch_sampler=batch_sampler,\n                collate_fn=packed_collate_fn,\n                num_workers=self.num_workers,\n                pin_memory=True,\n            )\n        else:\n            return DataLoader(\n                
self.train_dataset,\n                batch_size=self.batch_size,\n                sampler=sampler,\n                num_workers=self.num_workers,\n                pin_memory=True,\n                drop_last=True,\n            )\n    def test_dataloader(self):\n        return None\n\n    def val_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=True\n        )\n\n\n\n\n"
  },
  {
    "path": "nit/data/sampler_util.py",
    "content": "import torch\nimport json\n\n# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60\ndef get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps,\n                      resume_step, seed):\n    sample_indices = torch.empty([max_steps * global_batch_size // world_size],\n                                 dtype=torch.long)\n    epoch_id, fill_ptr, offs = 0, 0, 0\n    while fill_ptr < sample_indices.size(0):\n        g = torch.Generator()\n        g.manual_seed(seed + epoch_id)\n        epoch_sample_indices = torch.randperm(len(dataset), generator=g)\n        epoch_id += 1\n        epoch_sample_indices = epoch_sample_indices[\n            (rank + offs) % world_size::world_size\n        ]\n        offs = (offs + world_size - len(dataset) % world_size) % world_size\n        epoch_sample_indices = epoch_sample_indices[\n            :sample_indices.size(0) - fill_ptr\n        ]\n        sample_indices[fill_ptr: fill_ptr + epoch_sample_indices.size(0)] = \\\n            epoch_sample_indices\n        fill_ptr += epoch_sample_indices.size(0)\n    return sample_indices[resume_step * global_batch_size // world_size:].tolist()\n\n\n\n\ndef get_packed_batch_sampler(\n        dataset, rank, world_size, max_steps, resume_step, seed\n    ):\n    sample_indices = [None for _ in range(max_steps)]\n    epoch_id, fill_ptr, offs = 0, 0, 0\n    while fill_ptr < len(sample_indices):\n        g = torch.Generator()\n        g.manual_seed(seed + epoch_id)\n        epoch_sample_indices = torch.randperm(len(dataset), generator=g)\n        epoch_id += 1\n        epoch_sample_indices = epoch_sample_indices[\n            (rank + offs) % world_size::world_size\n        ]\n        offs = (offs + world_size - len(dataset) % world_size) % world_size\n        epoch_sample_indices = epoch_sample_indices[\n            :len(sample_indices) - fill_ptr\n        ]\n        sample_indices[fill_ptr: fill_ptr + 
epoch_sample_indices.size(0)] = [\n            dataset[i] for i in epoch_sample_indices\n        ]\n        fill_ptr += epoch_sample_indices.size(0)\n    return sample_indices[resume_step:]\n\n"
  },
  {
    "path": "nit/models/c2i/nit_model.py",
    "content": "# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n# --------------------------------------------------------\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport math\nfrom timm.models.vision_transformer import PatchEmbed, Mlp\nfrom einops import rearrange, repeat\nfrom flash_attn import flash_attn_varlen_func\nfrom nit.models.utils.funcs import get_parameter_dtype\nfrom nit.models.utils.pos_embeds.rope import VisionRotaryEmbedding, rotate_half\nfrom typing import Optional\n\ndef modulate(x, shift, scale):\n    return x * (1 + scale) + shift\n\ndef build_mlp(hidden_size, projector_dim, z_dim):\n    return nn.Sequential(\n                nn.Linear(hidden_size, projector_dim),\n                nn.SiLU(),\n                nn.Linear(projector_dim, projector_dim),\n                nn.SiLU(),\n                nn.Linear(projector_dim, z_dim),\n            )\n#################################################################################\n#               Embedding Layers for Timesteps and Class Labels                 #\n#################################################################################            \nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n    def __init__(self, hidden_size, frequency_embedding_size=256):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, hidden_size, bias=True),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n    \n    @staticmethod\n    def positional_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: 
the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        return embedding\n\n    def forward(self, t):\n        self.timestep_embedding = self.positional_embedding\n        t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\n\nclass LabelEmbedder(nn.Module):\n    \"\"\"\n    Embeds class labels into vector representations. 
Also handles label dropout for classifier-free guidance.\n    \"\"\"\n    def __init__(self, num_classes, hidden_size, dropout_prob):\n        super().__init__()\n        use_cfg_embedding = dropout_prob > 0\n        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)\n        self.num_classes = num_classes\n        self.dropout_prob = dropout_prob\n\n    def forward(self, labels):\n        embeddings = self.embedding_table(labels)\n        return embeddings\n\n\n#################################################################################\n#                                 Attention Block                               #\n#################################################################################\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim: int,\n            num_heads: int = 8,\n            qkv_bias: bool = False,\n            qk_norm: bool = False,\n            attn_drop: float = 0.,\n            proj_drop: float = 0.,\n            norm_layer: nn.Module = nn.LayerNorm,\n    ) -> None:\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: torch.Tensor, cu_seqlens, freqs_cos, freqs_sin) -> torch.Tensor:\n        N, C = x.shape\n        qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)\n        ori_dtype = qkv.dtype\n        q, k, v = qkv.unbind(0)\n        q, k = self.q_norm(q), self.k_norm(k)\n        \n   
     q = q * freqs_cos + rotate_half(q) * freqs_sin\n        k = k * freqs_cos + rotate_half(k) * freqs_sin\n        q, k = q.to(ori_dtype), k.to(ori_dtype)\n        \n        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()\n\n        x = flash_attn_varlen_func(\n            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen\n        ).reshape(N, -1)\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\n\n#################################################################################\n#                                 Core NiT Model                                #\n#################################################################################\n\nclass NiTBlock(nn.Module):\n    \"\"\"\n    A NiT block with adaptive layer norm zero (adaLN-Zero) conditioning.\n    \"\"\"\n    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.attn = Attention(\n            hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs[\"qk_norm\"]\n        )\n        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        approx_gelu = lambda: nn.GELU(approximate=\"tanh\")\n        self.mlp = Mlp(\n            in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0\n            )\n        use_adaln_lora = block_kwargs.get('use_adaln_lora', False)\n        if use_adaln_lora:\n            adaln_lora_dim = block_kwargs['adaln_lora_dim']\n            self.adaLN_modulation = nn.Sequential(\n                nn.SiLU(),\n                nn.Linear(hidden_size, adaln_lora_dim, bias=True),\n                nn.Linear(adaln_lora_dim, 6 * hidden_size, bias=True)\n            )\n        else:\n            self.adaLN_modulation = nn.Sequential(\n                nn.SiLU(),\n           
     nn.Linear(hidden_size, 6 * hidden_size, bias=True)\n            )\n\n    def forward(self, x, c, cu_seqlens, freqs_cos, freqs_sin):\n        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (\n            self.adaLN_modulation(c).chunk(6, dim=-1)\n        )\n        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), cu_seqlens, freqs_cos, freqs_sin)\n        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))\n\n        return x\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of NiT.\n    \"\"\"\n    def __init__(self, hidden_size, patch_size, out_channels):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)\n        self.adaLN_modulation = nn.Sequential(\n            nn.SiLU(),\n            nn.Linear(hidden_size, 2 * hidden_size, bias=True)\n        )\n\n    def forward(self, x, c):\n        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)\n        x = modulate(self.norm_final(x), shift, scale)\n        x = self.linear(x)\n\n        return x\n\n\nclass NiT(nn.Module):\n    \"\"\"\n    Diffusion model with a Transformer backbone.\n    \"\"\"\n    def __init__(\n        self,\n        input_size=32,\n        patch_size=2,\n        in_channels=4,\n        hidden_size=1152,\n        depth=28,\n        num_heads=16,\n        mlp_ratio=4.0,\n        class_dropout_prob=0.1,\n        num_classes=1000,\n        encoder_depth=4,\n        projector_dim=2048,\n        z_dim=768,\n        use_checkpoint: bool = False,\n        custom_freqs: str = 'normal',\n        theta: int = 10000,\n        max_pe_len_h: Optional[int] = None,\n        max_pe_len_w: Optional[int] = None,\n        decouple: bool = False,\n        ori_max_pe_len: Optional[int] = None,\n        **block_kwargs # fused_attn\n    ):\n        
super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = in_channels\n        self.patch_size = patch_size\n        self.num_heads = num_heads\n        self.num_classes = num_classes\n        self.encoder_depth = encoder_depth\n        self.use_checkpoint = use_checkpoint\n        \n        self.x_embedder = PatchEmbed(\n            input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=False\n        )\n        self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type\n        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)\n        self.rope = VisionRotaryEmbedding(\n            head_dim=hidden_size//num_heads, custom_freqs=custom_freqs, theta=theta,\n            max_pe_len_h=max_pe_len_h, max_pe_len_w=max_pe_len_w, decouple=decouple,\n            ori_max_pe_len=ori_max_pe_len\n        )\n\n        self.projector = build_mlp(hidden_size, projector_dim, z_dim) \n        \n        self.blocks = nn.ModuleList([\n            NiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)\n        ])\n        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        # Initialize transformer layers:\n        def _basic_init(module):\n            if isinstance(module, nn.Linear):\n                torch.nn.init.xavier_uniform_(module.weight)\n                if module.bias is not None:\n                    nn.init.constant_(module.bias, 0)\n        self.apply(_basic_init)\n\n        \n        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):\n        w = self.x_embedder.proj.weight.data\n        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))\n        nn.init.constant_(self.x_embedder.proj.bias, 0)\n\n        # Initialize label embedding table:\n        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)\n\n        # 
Initialize timestep embedding MLP:\n        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)\n        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)\n\n        # Zero-out adaLN modulation layers in NiT blocks:\n        for block in self.blocks:\n            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)\n            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)\n\n        # Zero-out output layers:\n        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)\n        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)\n        nn.init.constant_(self.final_layer.linear.weight, 0)\n        nn.init.constant_(self.final_layer.linear.bias, 0)\n\n    def unpatchify(self, x, patch_size=None):\n        \"\"\"\n        x: (N, T, patch_size**2 * C)\n        imgs: (N, H, W, C)\n        \"\"\"\n        c = self.out_channels\n        p = self.x_embedder.patch_size[0] if patch_size is None else patch_size\n        h = w = int(x.shape[1] ** 0.5)\n        assert h * w == x.shape[1]\n\n        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))\n        x = torch.einsum('nhwpqc->nchpwq', x)\n        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))\n        return imgs\n    \n    def get_rope(self, hw_list):\n        grids = []\n        for h, w in hw_list:\n            grid_h = torch.arange(h)\n            grid_w = torch.arange(w)\n            grid = torch.meshgrid(grid_h, grid_w, indexing='xy') \n            grid = torch.stack(grid, dim=0).reshape(2, -1)\n            grids.append(grid)\n        grids = torch.cat(grids, dim=-1)\n        freqs_cos, freqs_sin = self.rope.get_cached_2d_rope_from_grid(grids)\n        return freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)\n\n    def forward(self, x, t, y, hw_list, return_zs=False, return_logvar=False):\n        \"\"\"\n        Forward pass of NiT.\n        x: (N, C, p, p) tensor of spatial inputs (images or latent representations of images)\n        t: (N,) tensor of 
diffusion timesteps\n        y: (N,) tensor of class labels\n        \"\"\"\n        x = self.x_embedder(x)                  # (N, C, p, p) -> (N, 1, D), where T = H * W / patch_size ** 2\n        x = x.squeeze(1)                        # (N, D)\n        B = hw_list.shape[0]\n\n        freqs_cos, freqs_sin = self.get_rope(hw_list)   # (N, D_h)\n        seqlens = hw_list[:, 0] * hw_list[:, 1]\n        cu_seqlens = torch.cat([\n            torch.tensor([0], device=hw_list.device, dtype=torch.int), \n            torch.cumsum(seqlens, dim=0, dtype=torch.int)\n        ])\n\n        # timestep and class embedding\n        t_embed = self.t_embedder(t)            # (B, D)\n        y = self.y_embedder(y)                  # (B, D)\n        c = t_embed + y                         # (B, D)\n        \n        # (B, D) -> (N, D)\n        c = torch.cat([c[i].unsqueeze(0).repeat(seqlens[i], 1) for i in range(B)], dim=0)\n        \n        zs=[]\n        for i, block in enumerate(self.blocks):\n            if not self.use_checkpoint:\n                x = block(x, c, cu_seqlens, freqs_cos, freqs_sin)   # (N, D)\n            else:\n                x = torch.utils.checkpoint.checkpoint(\n                    self.ckpt_wrapper(block), x, c, cu_seqlens, freqs_cos, freqs_sin\n                )  \n            if (i + 1) == self.encoder_depth and return_zs:\n                zs = [self.projector(x)]\n        x = self.final_layer(x, c)              # (N, out_channels * patch_size ** 2)\n        \n        # (N, out_channels * patch_size ** 2) -> (N, out_channels, p, p)\n        x = rearrange(x, 'n (c p1 p2) -> n c p1 p2', p1=self.patch_size, p2=self.patch_size)                  \n        if return_zs:\n            return x, zs\n        else:\n            return x  \n\n\n    def ckpt_wrapper(self, module):\n        def ckpt_forward(*inputs):\n            outputs = module(*inputs)\n            return outputs\n        return ckpt_forward\n    \n    @property\n    def dtype(self) -> torch.dtype:\n 
       \"\"\"\n        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).\n        \"\"\"\n        return get_parameter_dtype(self)\n\n"
  },
  {
    "path": "nit/models/nvidia_radio/hubconf.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\ndependencies = [\"torch\", \"timm\", \"einops\"]\n\nimport os\nfrom typing import Dict, Any, Optional, Union, List\nimport warnings\n\nimport torch\nfrom torch.hub import load_state_dict_from_url\n\nfrom timm.models import clean_state_dict\n\nfrom .radio.adaptor_registry import adaptor_registry\nfrom .radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP\nfrom .radio.enable_damp import configure_damp_from_args\nfrom .radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args\nfrom .radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer\nfrom .radio.radio_model import RADIOModel, create_model_from_args\nfrom .radio.input_conditioner import get_default_conditioner\nfrom .radio.vitdet import apply_vitdet_arch, VitDetArgs\n\n\ndef radio_model(\n    version: str = \"\",\n    progress: bool = True,\n    adaptor_names: Union[str, List[str]] = None,\n    vitdet_window_size: Optional[int] = None,\n    return_checkpoint: bool = False,\n    support_packing: bool=False,\n    **kwargs,\n) -> RADIOModel:\n    if not version:\n        version = DEFAULT_VERSION\n\n    if os.path.isfile(version):\n        chk = torch.load(version, map_location=\"cpu\", weights_only=False)\n        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)\n    else:\n        resource = RESOURCE_MAP[version]\n        chk = load_state_dict_from_url(\n            resource.url, progress=progress, map_location=\"cpu\", weights_only=False,\n      
  )\n\n    if \"state_dict_ema\" in chk:\n        state_dict = chk[\"state_dict_ema\"]\n        chk['args'].spectral_reparam = False\n    else:\n        state_dict = chk[\"state_dict\"]\n\n    args = chk[\"args\"]\n    args.support_packing = support_packing\n    mod = create_model_from_args(args)\n\n    mod_state_dict = get_prefix_state_dict(state_dict, \"base_model.\")\n\n    if args.spectral_reparam:\n        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)\n\n    if getattr(args, 'damp', None):\n        configure_damp_from_args(mod, args)\n\n    state_dict = clean_state_dict(state_dict)\n\n    key_warn = mod.load_state_dict(mod_state_dict, strict=False)\n    if key_warn.missing_keys:\n        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')\n    if key_warn.unexpected_keys:\n        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')\n\n    if chk['args'].spectral_reparam:\n        # Spectral reparametrization uses PyTorch's \"parametrizations\" API. The idea behind\n        # the method is that instead of there being a `weight` tensor for certain Linear layers\n        # in the model, we make it a dynamically computed function. During training, this\n        # helps stabilize the model. 
However, for downstream use cases, it shouldn't be necessary.\n        # Disabling it in this context means that instead of having `w' = f(w)`, we just compute `w' = f(w)`\n        # once, during this function call, and replace the parametrization with the realized weights.\n        # This makes the model run faster, and also use less memory.\n        disable_spectral_reparam(mod)\n        chk['args'].spectral_reparam = False\n\n    conditioner = get_default_conditioner()\n    conditioner.load_state_dict(get_prefix_state_dict(state_dict, \"input_conditioner.\"))\n\n    dtype = getattr(chk['args'], 'dtype', torch.float32)\n    mod.to(dtype=dtype)\n    conditioner.dtype = dtype\n\n    cls_token_per_teacher = getattr(chk['args'], 'cls_token_per_teacher', True)\n    if cls_token_per_teacher:\n        name_to_idx_map = dict()\n        for i, t in enumerate(chk['args'].teachers):\n            if t.get('use_summary', True):\n                name = t['name']\n                if name not in name_to_idx_map:\n                    name_to_idx_map[name] = i\n        summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)\n    else:\n        summary_idxs = torch.tensor([0], dtype=torch.int64)\n\n    if adaptor_names is None:\n        adaptor_names = []\n    elif isinstance(adaptor_names, str):\n        adaptor_names = [adaptor_names]\n\n    teachers = chk[\"args\"].teachers\n    adaptors = dict()\n    for adaptor_name in adaptor_names:\n        for tidx, tconf in enumerate(teachers):\n            if tconf[\"name\"] == adaptor_name:\n                break\n        else:\n            raise ValueError(f'Unable to find the specified adaptor name. 
Known names: {list(t[\"name\"] for t in teachers)}')\n\n        ttype = tconf[\"type\"]\n\n        pf_idx_head = f'_heads.{tidx}'\n        pf_name_head = f'_heads.{adaptor_name}'\n        pf_idx_feat = f'_feature_projections.{tidx}'\n        pf_name_feat = f'_feature_projections.{adaptor_name}'\n\n        adaptor_state = dict()\n        for k, v in state_dict.items():\n            if k.startswith(pf_idx_head):\n                adaptor_state['summary' + k[len(pf_idx_head):]] = v\n            elif k.startswith(pf_name_head):\n                adaptor_state['summary' + k[len(pf_name_head):]] = v\n            elif k.startswith(pf_idx_feat):\n                adaptor_state['feature' + k[len(pf_idx_feat):]] = v\n            elif k.startswith(pf_name_feat):\n                adaptor_state['feature' + k[len(pf_name_feat):]] = v\n\n        adaptor = adaptor_registry.create_adaptor(ttype, chk[\"args\"], tconf, adaptor_state)\n        adaptor.head_idx = tidx if cls_token_per_teacher else 0\n        adaptors[adaptor_name] = adaptor\n\n    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')\n    feature_normalizer = None\n    if feat_norm_sd:\n        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)\n        feature_normalizer.load_state_dict(feat_norm_sd)\n\n    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')\n    inter_feature_normalizer = None\n    if inter_feat_norm_sd:\n        inter_feature_normalizer = IntermediateFeatureNormalizer(\n            *inter_feat_norm_sd['means'].shape[:2],\n            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,\n            dtype=dtype\n        )\n        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)\n\n    radio = RADIOModel(\n        mod,\n        conditioner,\n        summary_idxs=summary_idxs,\n        patch_size=resource.patch_size,\n        max_resolution=resource.max_resolution,\n        
window_size=vitdet_window_size,\n        preferred_resolution=resource.preferred_resolution,\n        adaptors=adaptors,\n        feature_normalizer=feature_normalizer,\n        inter_feature_normalizer=inter_feature_normalizer,\n    )\n\n    if vitdet_window_size is not None:\n        apply_vitdet_arch(\n            mod,\n            VitDetArgs(\n                vitdet_window_size,\n                radio.num_summary_tokens,\n                num_windowed=resource.vitdet_num_windowed,\n                num_global=resource.vitdet_num_global,\n            ),\n        )\n\n    if return_checkpoint:\n        return radio, chk\n    return radio\n\n\ndef get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):\n    mod_state_dict = {\n        k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)\n    }\n    return mod_state_dict\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/__init__.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n# Register the adaptors\nfrom .adaptor_registry import adaptor_registry\nfrom . import open_clip_adaptor\nfrom .adaptor_base import AdaptorInput, RadioOutput, AdaptorBase\n\n# Enable support for other model types via the timm register_model mechanism\nfrom . import extra_timm_models\nfrom . import extra_models\nfrom . import vision_transformer_xpos\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_base.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom argparse import Namespace\nfrom typing import NamedTuple, Optional\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\n\nclass AdaptorInput(NamedTuple):\n    images: torch.Tensor\n    summary: torch.Tensor\n    features: torch.Tensor\n    feature_fmt: str\n    patch_size: int\n\n\nclass RadioOutput(NamedTuple):\n    summary: torch.Tensor\n    features: torch.Tensor\n\n    def to(self, *args, **kwargs):\n        return RadioOutput(\n            self.summary.to(*args, **kwargs) if self.summary is not None else None,\n            self.features.to(*args, **kwargs) if self.features is not None else None,\n        )\n\n\nclass AdaptorBase(nn.Module):\n    def forward(self, input: AdaptorInput) -> RadioOutput:\n        raise NotImplementedError(\"Subclasses must implement this!\")\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_generic.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom argparse import Namespace\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput\nfrom .adaptor_mlp import create_mlp_from_state, create_mlp_from_config\n\n\nclass GenericAdaptor(AdaptorBase):\n    def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):\n        super().__init__()\n\n        extra_args = dict()\n        ups = None\n        ups_rank = None\n        if adaptor_config is not None:\n            ups = adaptor_config.get('fd_upsample_factor', None)\n            ups_rank = adaptor_config.get('fd_upsample_rank', None)\n        elif mlp_config is not None:\n            ups = mlp_config[\"feature\"].get('upsample_factor', None)\n            ups_rank = mlp_config[\"feature\"].get('upsample_rank', None)\n        if ups is not None:\n            extra_args['upsample_factor'] = ups\n            extra_args['upsample_rank'] = ups_rank\n\n        if state is not None:\n            spectral_heads = getattr(main_config, 'spectral_heads', False)\n            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)\n            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)\n        else:\n            assert mlp_config is not None, \"Config must not be None if state is None\"\n\n            self.head_mlp =  create_mlp_from_config(\n                main_config.mlp_version,\n    
            mlp_config[\"summary\"][\"input_dim\"],\n                mlp_config[\"summary\"][\"hidden_dim\"],\n                mlp_config[\"summary\"][\"output_dim\"],\n                mlp_config[\"summary\"][\"num_inner\"],\n            )\n            self.feat_mlp = create_mlp_from_config(\n                main_config.mlp_version,\n                mlp_config[\"feature\"][\"input_dim\"],\n                mlp_config[\"feature\"][\"hidden_dim\"],\n                mlp_config[\"feature\"][\"output_dim\"],\n                mlp_config[\"feature\"][\"num_inner\"],\n                **extra_args\n            )\n\n    def forward(self, input: AdaptorInput) -> RadioOutput:\n        # Convert input'd type to the type of the first parameter of the adaptor.\n        first_param = next(self.parameters())\n        summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)\n        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)\n\n        if input.feature_fmt == 'NCHW':\n            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])\n                        .permute(0, 3, 1, 2)\n            )\n\n        return RadioOutput(summary, feat)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_mlp.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nimport math\nfrom typing import Dict, Optional\n\nimport torch\nfrom torch import nn\n\nfrom einops import rearrange\nfrom timm.models.vision_transformer import Block\n\nfrom .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam\n\n\nclass MLP(nn.Module):\n    def __init__(self, input_size: int, hidden_size: int, output_size: int,\n                 num_inner: int = 0, device: torch.device = None, **kwargs):\n        super(MLP, self).__init__()\n        self.fc1 = nn.Linear(input_size, hidden_size, device=device)\n        self.norm = nn.LayerNorm(hidden_size, device=device)\n        self.relu = nn.ReLU()\n\n        inner = []\n        for _ in range(num_inner):\n            inner.extend([\n                nn.Linear(hidden_size, hidden_size, device=device),\n                nn.LayerNorm(hidden_size, device=device),\n                nn.ReLU(),\n            ])\n        if inner:\n            self.inner = nn.Sequential(*inner)\n        else:\n            self.inner = nn.Identity()\n\n        self.fc2 = nn.Linear(hidden_size, output_size, device=device)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc1(x)\n        x = self.norm(x)\n        x = self.relu(x)\n        x = self.inner(x)\n        x = self.fc2(x)\n        return x\n\n\nclass MLP2(nn.Module):\n    def __init__(self, input_size: int, hidden_size: int, output_size: int,\n                 num_inner: int = 0,\n                 pre_norm: bool = False, device: torch.device = None,\n                 upsample_factor: int 
= 1,\n                 upsample_rank: int = None,\n                 from_config: bool = False,\n                 **kwargs):\n        super().__init__()\n\n        self.pre_norm = nn.Sequential(\n            nn.LayerNorm(input_size),\n            nn.GELU(),\n        ) if pre_norm else nn.Identity()\n\n        self.upsample_factor = upsample_factor\n        sq_ups = upsample_factor ** 2\n\n        self._real_output_dim = output_size // sq_ups\n\n        # hidden_size *= upsample_factor\n        # output_size *= (upsample_factor ** 2)\n\n        self.fc1 = nn.Linear(input_size, hidden_size, device=device)\n\n        blocks = []\n        for _ in range(num_inner):\n            blocks.append(nn.Sequential(\n                nn.LayerNorm(hidden_size, device=device),\n                nn.GELU(),\n                nn.Linear(hidden_size, hidden_size, device=device),\n            ))\n        self.blocks = nn.ModuleList(blocks)\n\n        self.final = nn.Sequential(\n            nn.LayerNorm(hidden_size, device=device),\n            nn.GELU(),\n            nn.Linear(hidden_size, output_size, device=device),\n        )\n\n    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:\n        x = self.pre_norm(x)\n        x = self.fc1(x)\n        for block in self.blocks:\n            x = x + block(x)\n        x = self.final(x)\n\n        if self.upsample_factor > 1:\n            if images is None:\n                raise ValueError(f'`images` cannot be `None` when the head\\'s `upsample_factor > 1`!')\n            if patch_size is None:\n                raise ValueError(f'`patch_size` cannot be `None` when the head\\'s `upsample_factor > 1`!')\n            h, w = tuple(d // patch_size for d in images.shape[-2:])\n            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',\n                          h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,\n                          
c=self._real_output_dim)\n\n        return x\n\n\nMLP_FACTORY = {\n    'v1': MLP,\n    'v2': MLP2,\n}\n\n\ndef strip_prefix(state: Dict[str, torch.Tensor], prefix: str):\n    state = {\n        k[len(prefix):]: v\n        for k, v in state.items()\n        if k.startswith(prefix)\n    }\n    return state\n\n\ndef get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):\n    state = strip_prefix(state, prefix)\n\n    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'\n\n    if version == 'v1':\n        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape\n        output_dim = state[f'fc2.{weight_suffix}'].shape[0]\n\n        for num_inner in range(1000):\n            k = f'inner.{num_inner}.0.weight'\n            if k not in state:\n                break\n    elif version == 'v2':\n        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape\n        output_dim = state[f'final.2.{weight_suffix}'].shape[0]\n\n        for num_inner in range(1000):\n            k = f'blocks.{num_inner}.0.weight'\n            if k not in state:\n                break\n    else:\n        raise ValueError(f'Unsupported MLP version: {version}')\n\n    return input_dim, hidden_dim, output_dim, num_inner\n\n\ndef create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):\n    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)\n\n    return ret\n\n\ndef create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):\n    state = strip_prefix(state, prefix)\n\n    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)\n\n    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)\n\n    
if spectral_weights:\n        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)\n\n    ret.load_state_dict(state)\n\n    if spectral_weights:\n        disable_spectral_reparam(ret)\n\n    return ret\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/adaptor_registry.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom argparse import Namespace\nfrom typing import Dict, Any\n\nimport torch\n\nfrom .adaptor_generic import GenericAdaptor, AdaptorBase\n\ndict_t = Dict[str, Any]\nstate_t = Dict[str, torch.Tensor]\n\n\nclass AdaptorRegistry:\n    def __init__(self):\n        self._registry = {}\n\n    def register_adaptor(self, name):\n        def decorator(factory_function):\n            if name in self._registry:\n                raise ValueError(f\"Model '{name}' already registered\")\n            self._registry[name] = factory_function\n            return factory_function\n        return decorator\n\n    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:\n        if name not in self._registry:\n            return GenericAdaptor(main_config, adaptor_config, state)\n        return self._registry[name](main_config, adaptor_config, state)\n\n# Creating an instance of the registry\nadaptor_registry = AdaptorRegistry()\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/block.py",
    "content": "# Ultralytics YOLO 🚀, AGPL-3.0 license\n\"\"\"\nBlock modules\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import DropPath\n\nfrom .conv import Conv\n# from .transformer import TransformerBlock\n\n__all__ = ('C2f', 'Bottleneck',)\n\nclass C2f(nn.Module):\n    \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n\n    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):  # ch_in, ch_out, number, shortcut, groups, expansion\n        super().__init__()\n        if drop_path is None:\n            drop_path = [0.0] * n\n\n        self.c = int(c2 * e)  # hidden channels\n        self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)\n        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))\n\n    def forward(self, x):\n        \"\"\"Forward pass through C2f layer.\"\"\"\n        y = list(self.cv1(x).chunk(2, 1))\n        y.extend(m(y[-1]) for m in self.m)\n        return self.cv2(torch.cat(y, 1))\n\n    def forward_split(self, x):\n        \"\"\"Forward pass using split() instead of chunk().\"\"\"\n        y = list(self.cv1(x).split((self.c, self.c), 1))\n        y.extend(m(y[-1]) for m in self.m)\n        return self.cv2(torch.cat(y, 1))\n\n\nclass Bottleneck(nn.Module):\n    \"\"\"Standard bottleneck.\"\"\"\n\n    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):  # ch_in, ch_out, shortcut, groups, kernels, expand\n        super().__init__()\n        c_ = int(c2 * e)  # hidden channels\n        self.cv1 = Conv(c1, c_, k[0], 1)\n        self.cv2 = Conv(c_, c2, k[1], 1, g=g)\n        self.add = shortcut and c1 == c2\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n    def forward(self, x):\n        \"\"\"'forward()' applies the YOLOv5 FPN to input data.\"\"\"\n        return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/cls_token.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom typing import Optional\n\nimport torch\nfrom torch import nn\n\n\nclass ClsToken(nn.Module):\n    def __init__(self, ndim: int,\n                 num_tokens: int = 1,\n                 enabled: bool = True,\n                 register_multiple: Optional[int] = None,\n                 num_registers: Optional[int] = None,\n    ):\n        super().__init__()\n\n        self.ndim = ndim\n        self.enabled = enabled\n        self.num_registers = 0\n        self.num_tokens = num_tokens\n        if enabled:\n            if num_registers:\n                self.num_registers = num_registers\n            elif register_multiple:\n                self.num_registers = register_multiple - (num_tokens % register_multiple)\n\n            scale = ndim ** -0.5\n            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)\n        else:\n            self.token = None\n\n        self.num_patches = self.num_tokens + self.num_registers\n\n    def disable(self):\n        self.token = None\n        self.enabled = False\n\n    def forward(self, x: torch.Tensor):\n        if self.token is None:\n            return x\n\n        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)\n        x = torch.cat([\n            token,\n            x,\n        ], dim=1)\n\n        return x\n\n    def no_weight_decay(self):\n        return [\n            'token',\n        ]\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/common.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nfrom .radio_model import Resolution\n\n\n@dataclass\nclass RadioResource:\n    url: str\n    patch_size: int\n    max_resolution: int\n    preferred_resolution: Resolution\n    vitdet_num_windowed: Optional[int] = None\n    vitdet_num_global: Optional[int] = None\n\n\nRESOURCE_MAP = {\n    # RADIOv2.5\n    \"radio_v2.5-b\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=(768, 768),\n        vitdet_num_global=4,\n    ),\n    \"radio_v2.5-l\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=(768, 768),\n        vitdet_num_global=4,\n    ),\n    \"radio_v2.5-h\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=(768, 768),\n        vitdet_num_global=4,\n    ),\n    \"radio_v2.5-h-norm\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=(768, 768),\n        vitdet_num_global=4,\n    ),\n    \"radio_v2.5-g\": RadioResource(\n        
\"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true\",\n        patch_size=14,\n        max_resolution=1792,\n        preferred_resolution=(896, 896),\n        vitdet_num_global=8,\n    ),\n    # RADIO\n    \"radio_v2.1\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=Resolution(432, 432),\n        vitdet_num_windowed=5,\n    ),\n    \"radio_v2\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=Resolution(432, 432),\n        vitdet_num_windowed=5,\n    ),\n    \"radio_v1\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true\",\n        patch_size=14,\n        max_resolution=1050,\n        preferred_resolution=Resolution(378, 378),\n    ),\n    # E-RADIO\n    \"e-radio_v2\": RadioResource(\n        \"https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=Resolution(512, 512),\n    ),\n    # C-RADIO\n    \"c-radio_v2.5-g\": RadioResource(\n        \"https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=(768, 768),\n        vitdet_num_global=8,\n    ),\n    \"c-radio_v3-l\": RadioResource(\n        # NOTE: Currently, this model cannot be loaded via TorchHub. 
Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L\n        # and accept the license terms.\n        \"https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true\",\n        patch_size=16,\n        max_resolution=2048,\n        preferred_resolution=Resolution(512, 512),\n    ),\n}\n\nDEFAULT_VERSION = \"radio_v2.5-h\"\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/conv.py",
    "content": "# Ultralytics YOLO 🚀, AGPL-3.0 license\n\"\"\"\nConvolution modules\n\"\"\"\n\nimport math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',\n           'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')\n\n\ndef autopad(k, p=None, d=1):  # kernel, padding, dilation\n    \"\"\"Pad to 'same' shape outputs.\"\"\"\n    if d > 1:\n        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size\n    if p is None:\n        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad\n    return p\n\n# Pavlo's implementation with switch to deploy\nclass Conv(nn.Module):\n    default_act = nn.SiLU()  # default activation\n\n    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):\n        super().__init__()\n\n        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)\n        if 1:\n            self.bn = torch.nn.BatchNorm2d(b)\n            torch.nn.init.constant_(self.bn.weight, bn_weight_init)\n            torch.nn.init.constant_(self.bn.bias, 0)\n        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n\n    def forward(self,x):\n        x = self.conv(x)\n        x = self.bn(x)\n        x = self.act(x)\n        return x\n\n    @torch.no_grad()\n    def switch_to_deploy(self):\n        if not isinstance(self.bn, nn.Identity):\n            # return 1\n            c, bn = self.conv, self.bn\n            w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n            w = c.weight * w[:, None, None, None]\n            b = bn.bias - bn.running_mean * bn.weight / \\\n                (bn.running_var + bn.eps)**0.5\n            # m = torch.nn.Conv2d(w.size(1) * c.groups,\n            #          
           w.size(0),\n            #                     w.shape[2:],\n            #                     stride=c.stride,\n            #                     padding=c.padding,\n            #                     dilation=c.dilation,\n            #                     groups=c.groups)\n            self.conv.weight.data.copy_(w)\n            self.conv.bias = nn.Parameter(b)\n            # self.conv.bias.data.copy_(b)\n            # self.conv = m.to(c.weight.device)\n            self.bn = nn.Identity()\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/dinov2_arch.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\n# Nvidia\n# NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking,\n# but also because Huggingface does a string replace of `gamma` to something else when loading the model state,\n# and this breaks loading of this model.\n\nfrom enum import Enum\nfrom functools import partial\nimport logging\nimport math\nimport os\nimport sys\nfrom typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union\nimport warnings\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import trunc_normal_\n\n_torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention')\n\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\ntry:\n    if XFORMERS_ENABLED:\n        from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind\n\n        XFORMERS_AVAILABLE = True\n    else:\n        raise ImportError\nexcept ImportError:\n    XFORMERS_AVAILABLE = False\n\n\ndef make_2tuple(x):\n    if isinstance(x, tuple):\n        assert len(x) == 2\n        return x\n\n    assert isinstance(x, int)\n    return (x, x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    2D image to patch embedding: (B,C,H,W) -> (B,N,D)\n\n    Args:\n        img_size: Image size.\n        patch_size: Patch token size.\n        in_chans: Number of input image channels.\n        embed_dim: Number of linear projection output channels.\n        norm_layer: Normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        
img_size: Union[int, Tuple[int, int]] = 224,\n        patch_size: Union[int, Tuple[int, int]] = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        norm_layer: Optional[Callable] = None,\n        flatten_embedding: bool = True,\n    ) -> None:\n        super().__init__()\n\n        image_HW = make_2tuple(img_size)\n        patch_HW = make_2tuple(patch_size)\n        patch_grid_size = (\n            image_HW[0] // patch_HW[0],\n            image_HW[1] // patch_HW[1],\n        )\n\n        self.img_size = image_HW\n        self.patch_size = patch_HW\n        self.patches_resolution = patch_grid_size\n        self.num_patches = patch_grid_size[0] * patch_grid_size[1]\n\n        self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.flatten_embedding = flatten_embedding\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)\n        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        _, _, H, W = x.shape\n        patch_H, patch_W = self.patch_size\n\n        assert H % patch_H == 0, f\"Input image height {H} is not a multiple of patch height {patch_H}\"\n        assert W % patch_W == 0, f\"Input image width {W} is not a multiple of patch width: {patch_W}\"\n\n        x = self.proj(x)  # B C H W\n        H, W = x.size(2), x.size(3)\n        x = x.flatten(2).transpose(1, 2)  # B HW C\n        x = self.norm(x)\n        if not self.flatten_embedding:\n            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C\n        return x\n\n    def flops(self) -> float:\n        Ho, Wo = self.patches_resolution\n        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        if self.norm is not None:\n            flops += Ho * Wo * self.embed_dim\n        return flops\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: 
int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        q, k, v = qkv[0], qkv[1], qkv[2]\n        if _torch_has_sdpa:\n            x = F.scaled_dot_product_attention(\n                q, k, v,\n                is_causal=False,\n                dropout_p=self.attn_drop.p if self.training else 0.,\n                scale=self.scale,\n            )\n        else:\n            q = q * self.scale\n            attn = q @ k.transpose(-2, -1)\n\n            attn = attn.softmax(dim=-1)\n            attn = self.attn_drop(attn)\n            x = attn @ v\n\n        x = x.transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:\n        if not XFORMERS_AVAILABLE:\n            if attn_bias is not None:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Mlp(nn.Module):\n 
   def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)\n        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x12 = self.w12(x)\n        x1, x2 = x12.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n        return self.w3(hidden)\n\n\nif not XFORMERS_AVAILABLE:\n    SwiGLU = SwiGLUFFN\n\n\nclass SwiGLUFFNFused(SwiGLU):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n       
 out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8\n        super().__init__(\n            in_features=in_features,\n            hidden_features=hidden_features,\n            out_features=out_features,\n            bias=bias,\n        )\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0:\n        random_tensor.div_(keep_prob)\n    output = x * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n\n\nclass LayerScale(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        init_values: Union[float, torch.Tensor] = 1e-5,\n        inplace: bool = False,\n    ) -> None:\n        super().__init__()\n        self.inplace = inplace\n        self.grandma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return x.mul_(self.grandma) if self.inplace else x * self.grandma\n\n    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):\n        # Huggingface is absurd and it will rename strings that contain `gamma`, which means that the normal DINO implementation\n        # of LayerScale won't work with HFHub. 
So we rename the variable to 'grandma', and support loading checkpoints in either\n        # format\n        key_a = f'{prefix}gamma'\n        key_b = f'{prefix}grandma'\n        if key_a in state_dict:\n            gamma = state_dict[key_a]\n        elif key_b in state_dict:\n            gamma = state_dict[key_b]\n        else:\n            if strict:\n                raise KeyError(f\"Couldn't find the key {key_a} nor {key_b} in the state dict!\")\n            else:\n                missing_keys.append(key_a)\n                missing_keys.append(key_b)\n                unexpected_keys.extend(state_dict.keys())\n                gamma = None\n\n        if gamma is not None:\n            self.grandma.data.copy_(gamma)\n\n        # return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        attn_class: Callable[..., nn.Module] = Attention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n    ) -> None:\n        super().__init__()\n        # print(f\"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}\")\n        self.norm1 = norm_layer(dim)\n        self.attn = attn_class(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path1 = DropPath(drop_path) if drop_path > 
0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        def attn_residual_func(x: torch.Tensor) -> torch.Tensor:\n            return self.ls1(self.attn(self.norm1(x)))\n\n        def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:\n            return self.ls2(self.mlp(self.norm2(x)))\n\n        if self.training and self.sample_drop_ratio > 0.1:\n            # the overhead is compensated only for a drop path rate larger than 0.1\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n        elif self.training and self.sample_drop_ratio > 0.0:\n            x = x + self.drop_path1(attn_residual_func(x))\n            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2\n        else:\n            x = x + attn_residual_func(x)\n            x = x + ffn_residual_func(x)\n        return x\n\n\nclass NestedTensorBlock(Block):\n    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:\n        \"\"\"\n        x_list contains a list of tensors to nest together and run\n        \"\"\"\n        assert isinstance(self.attn, MemEffAttention)\n\n        if 
self.training and self.sample_drop_ratio > 0.0:\n\n            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:\n                return self.attn(self.norm1(x), attn_bias=attn_bias)\n\n            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:\n                return self.mlp(self.norm2(x))\n\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None,\n            )\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls2.grandma if isinstance(self.ls1, LayerScale) else None,\n            )\n            return x_list\n        else:\n\n            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:\n                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))\n\n            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:\n                return self.ls2(self.mlp(self.norm2(x)))\n\n            attn_bias, x = get_attn_bias_and_cat(x_list)\n            x = x + attn_residual_func(x, attn_bias=attn_bias)\n            x = x + ffn_residual_func(x)\n            return attn_bias.split(x)\n\n    def forward(self, x_or_x_list):\n        if isinstance(x_or_x_list, torch.Tensor):\n            return super().forward(x_or_x_list)\n        elif isinstance(x_or_x_list, list):\n            if not XFORMERS_AVAILABLE:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return self.forward_nested(x_or_x_list)\n        else:\n            raise AssertionError\n\n\ndef drop_add_residual_stochastic_depth(\n    x: 
torch.Tensor,\n    residual_func: Callable[[torch.Tensor], torch.Tensor],\n    sample_drop_ratio: float = 0.0,\n) -> torch.Tensor:\n    # 1) extract subset using permutation\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    x_subset = x[brange]\n\n    # 2) apply residual_func to get residual\n    residual = residual_func(x_subset)\n\n    x_flat = x.flatten(1)\n    residual = residual.flatten(1)\n\n    residual_scale_factor = b / sample_subset_size\n\n    # 3) add the residual\n    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    return x_plus_residual.view_as(x)\n\n\ndef get_branges_scales(x, sample_drop_ratio=0.0):\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    residual_scale_factor = b / sample_subset_size\n    return brange, residual_scale_factor\n\n\ndef add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):\n    if scaling_vector is None:\n        x_flat = x.flatten(1)\n        residual = residual.flatten(1)\n        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    else:\n        x_plus_residual = scaled_index_add(\n            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor\n        )\n    return x_plus_residual\n\n\nattn_bias_cache: Dict[Tuple, Any] = {}\n\n\ndef get_attn_bias_and_cat(x_list, branges=None):\n    \"\"\"\n    this will perform the index select, cat the tensors, and provide the attn_bias from cache\n    \"\"\"\n    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]\n    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))\n    if all_shapes not in 
attn_bias_cache.keys():\n        seqlens = []\n        for b, x in zip(batch_sizes, x_list):\n            for _ in range(b):\n                seqlens.append(x.shape[1])\n        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)\n        attn_bias._batch_sizes = batch_sizes\n        attn_bias_cache[all_shapes] = attn_bias\n\n    if branges is not None:\n        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])\n    else:\n        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)\n        cat_tensors = torch.cat(tensors_bs1, dim=1)\n\n    return attn_bias_cache[all_shapes], cat_tensors\n\n\ndef drop_add_residual_stochastic_depth_list(\n    x_list: List[torch.Tensor],\n    residual_func: Callable[[torch.Tensor, Any], torch.Tensor],\n    sample_drop_ratio: float = 0.0,\n    scaling_vector=None,\n) -> torch.Tensor:\n    # 1) generate random set of indices for dropping samples in the batch\n    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]\n    branges = [s[0] for s in branges_scales]\n    residual_scale_factors = [s[1] for s in branges_scales]\n\n    # 2) get attention bias and index+concat the tensors\n    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)\n\n    # 3) apply residual_func to get residual, and split the result\n    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore\n\n    outputs = []\n    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):\n        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))\n    return outputs\n\n\ndef named_apply(fn: Callable, module: nn.Module, name=\"\", depth_first=True, include_root=False) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n 
       child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass BlockChunk(nn.ModuleList):\n    def forward(self, x):\n        for b in self:\n            x = b(x)\n        return x\n\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        ffn_bias=True,\n        proj_bias=True,\n        drop_path_rate=0.0,\n        drop_path_uniform=False,\n        init_values=None,  # for layerscale: None or 0 => no layerscale\n        embed_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        block_fn=Block,\n        ffn_layer=\"mlp\",\n        block_chunks=1,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            proj_bias (bool): enable bias for proj in attn if True\n            ffn_bias (bool): enable bias for ffn if True\n            drop_path_rate (float): stochastic depth rate\n            drop_path_uniform (bool): apply uniform drop rate across blocks\n            weight_init (str): weight init scheme\n            init_values (float): layer-scale init values\n            embed_layer (nn.Module): patch embedding 
layer\n            act_layer (nn.Module): MLP activation layer\n            block_fn (nn.Module): transformer block class\n            ffn_layer (str): \"mlp\", \"swiglu\", \"swiglufused\" or \"identity\"\n            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap\n            num_register_tokens: (int) number of extra cls tokens (so-called \"registers\")\n            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings\n            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings\n        \"\"\"\n        super().__init__()\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 1\n        self.n_blocks = depth\n        self.num_heads = num_heads\n        self.patch_size = patch_size\n        self.num_register_tokens = num_register_tokens\n        self.interpolate_antialias = interpolate_antialias\n        self.interpolate_offset = interpolate_offset\n\n        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))\n        assert num_register_tokens >= 0\n        self.register_tokens = (\n            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None\n        )\n\n        if drop_path_uniform is True:\n            dpr = [drop_path_rate] * depth\n        else:\n            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n\n        if ffn_layer == \"mlp\":\n            ffn_layer = Mlp\n        elif ffn_layer == \"swiglufused\" or ffn_layer == 
\"swiglu\":\n            ffn_layer = SwiGLUFFNFused\n        elif ffn_layer == \"identity\":\n            def f(*args, **kwargs):\n                return nn.Identity()\n\n            ffn_layer = f\n        else:\n            raise NotImplementedError\n\n        blocks_list = [\n            block_fn(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                proj_bias=proj_bias,\n                ffn_bias=ffn_bias,\n                drop_path=dpr[i],\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                ffn_layer=ffn_layer,\n                init_values=init_values,\n            )\n            for i in range(depth)\n        ]\n        if block_chunks > 0:\n            self.chunked_blocks = True\n            chunked_blocks = []\n            chunksize = depth // block_chunks\n            for i in range(0, depth, chunksize):\n                # this is to keep the block index consistent if we chunk the block list\n                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])\n            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])\n        else:\n            self.chunked_blocks = False\n            self.blocks = nn.ModuleList(blocks_list)\n\n        self.norm = norm_layer(embed_dim)\n        self.head = nn.Identity()\n\n        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))\n\n    def interpolate_pos_encoding(self, x, w, h):\n        previous_dtype = x.dtype\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        pos_embed = self.pos_embed.float()\n        class_pos_embed = pos_embed[:, 0]\n        patch_pos_embed = pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        M = int(math.sqrt(N))  # Recover 
the number of patches in each dimension\n        assert N == M * M\n        kwargs = {}\n        if self.interpolate_offset:\n            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8\n            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors\n            sx = float(w0 + self.interpolate_offset) / M\n            sy = float(h0 + self.interpolate_offset) / M\n            kwargs[\"scale_factor\"] = (sx, sy)\n        else:\n            # Simply specify an output size instead of a scale factor\n            kwargs[\"size\"] = (w0, h0)\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),\n            mode=\"bicubic\",\n            antialias=self.interpolate_antialias,\n            **kwargs,\n        )\n        assert (w0, h0) == patch_pos_embed.shape[-2:]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)\n\n    def prepare_tokens_with_masks(self, x, masks=None):\n        B, nc, w, h = x.shape\n        x = self.patch_embed(x)\n        if masks is not None:\n            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)\n\n        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        if self.register_tokens is not None:\n            x = torch.cat(\n                (\n                    x[:, :1],\n                    self.register_tokens.expand(x.shape[0], -1, -1),\n                    x[:, 1:],\n                ),\n                dim=1,\n            )\n\n        return x\n\n    def forward_features_list(self, x_list, masks_list):\n        x = [self.prepare_tokens_with_masks(x, masks) for x, 
masks in zip(x_list, masks_list)]\n        for blk in self.blocks:\n            x = blk(x)\n\n        all_x = x\n        output = []\n        for x, masks in zip(all_x, masks_list):\n            x_norm = self.norm(x)\n            output.append(\n                {\n                    \"x_norm_clstoken\": x_norm[:, 0],\n                    \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n                    \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n                    \"x_prenorm\": x,\n                    \"masks\": masks,\n                }\n            )\n        return output\n\n    def forward_features(self, x, masks=None):\n        if isinstance(x, list):\n            return self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, masks)\n\n        for blk in self.blocks:\n            x = blk(x)\n\n        x_norm = self.norm(x)\n        return {\n            \"x_norm_clstoken\": x_norm[:, 0],\n            \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n            \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n            \"x_prenorm\": x,\n            \"masks\": masks,\n        }\n\n    def _get_intermediate_layers_not_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        # If n is an int, take the n last blocks. 
If it's a list, take them\n        output, total_block_len = [], len(self.blocks)\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in blocks_to_take:\n                output.append(x)\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def _get_intermediate_layers_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        output, i, total_block_len = [], 0, len(self.blocks[-1])\n        # If n is an int, take the n last blocks. If it's a list, take them\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for block_chunk in self.blocks:\n            for blk in block_chunk[i:]:  # Passing the nn.Identity()\n                x = blk(x)\n                if i in blocks_to_take:\n                    output.append(x)\n                i += 1\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def get_intermediate_layers(\n        self,\n        x: torch.Tensor,\n        n: Union[int, Sequence] = 1,  # Layers or n last layers to take\n        reshape: bool = False,\n        return_class_token: bool = False,\n        norm=True,\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:\n        if self.chunked_blocks:\n            outputs = self._get_intermediate_layers_chunked(x, n)\n        else:\n            outputs = self._get_intermediate_layers_not_chunked(x, n)\n        if norm:\n            outputs = [self.norm(out) for out in outputs]\n        class_tokens = [out[:, 0] for out in outputs]\n        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]\n        if reshape:\n            B, _, w, h = x.shape\n            outputs = [\n                out.reshape(B, w 
// self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()\n                for out in outputs\n            ]\n        if return_class_token:\n            return tuple(zip(outputs, class_tokens))\n        return tuple(outputs)\n\n    def forward(self, *args, is_training=False, **kwargs):\n        ret = self.forward_features(*args, **kwargs)\n        if is_training:\n            return ret\n        else:\n            return self.head(ret[\"x_norm_clstoken\"])\n\n\ndef vit_small(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_base(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_large(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\ndef vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):\n    \"\"\"\n    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64\n    \"\"\"\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1536,\n        depth=40,\n        num_heads=24,\n        mlp_ratio=4,\n        block_fn=partial(Block, 
attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\nclass Weights(Enum):\n    LVD142M = \"LVD142M\"\n\n\ndef _make_dinov2_model(\n    *,\n    arch_name: str = \"vit_large\",\n    img_size: int = 518,\n    patch_size: int = 14,\n    init_values: float = 1.0,\n    ffn_layer: str = \"mlp\",\n    block_chunks: int = 0,\n    num_register_tokens: int = 0,\n    interpolate_antialias: bool = False,\n    interpolate_offset: float = 0.1,\n    weights: Union[Weights, str] = Weights.LVD142M,\n    **kwargs,\n):\n    if isinstance(weights, str):\n        try:\n            weights = Weights[weights]\n        except KeyError:\n            raise AssertionError(f\"Unsupported weights: {weights}\")\n\n    vit_kwargs = dict(\n        img_size=img_size,\n        patch_size=patch_size,\n        init_values=init_values,\n        ffn_layer=ffn_layer,\n        block_chunks=block_chunks,\n        num_register_tokens=num_register_tokens,\n        interpolate_antialias=interpolate_antialias,\n        interpolate_offset=interpolate_offset,\n    )\n    vit_kwargs.update(**kwargs)\n    model = sys.modules[__name__].__dict__[arch_name](**vit_kwargs)\n\n    return model\n\n\ndef dinov2_vits14(**kwargs):\n    \"\"\"\n    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(arch_name=\"vit_small\", **kwargs)\n\n\ndef dinov2_vitb14(**kwargs):\n    \"\"\"\n    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(arch_name=\"vit_base\", **kwargs)\n\n\ndef dinov2_vitl14(**kwargs):\n    \"\"\"\n    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(arch_name=\"vit_large\", **kwargs)\n\n\ndef dinov2_vitg14(**kwargs):\n    \"\"\"\n    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(\n  
      arch_name=\"vit_giant2\",\n        ffn_layer=\"swiglufused\",\n        **kwargs,\n    )\n\n\ndef dinov2_vits14_reg(**kwargs):\n    \"\"\"\n    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(\n        arch_name=\"vit_small\",\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n\n\ndef dinov2_vitb14_reg(**kwargs):\n    \"\"\"\n    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(\n        arch_name=\"vit_base\",\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n\n\ndef dinov2_vitl14_reg(**kwargs):\n    \"\"\"\n    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(\n        arch_name=\"vit_large\",\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n\n\ndef dinov2_vitg14_reg(**kwargs):\n    \"\"\"\n    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.\n    \"\"\"\n    return _make_dinov2_model(\n        arch_name=\"vit_giant2\",\n        ffn_layer=\"swiglufused\",\n        num_register_tokens=4,\n        interpolate_antialias=True,\n        interpolate_offset=0.0,\n        **kwargs,\n    )\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/dual_hybrid_vit.py",
    "content": "from logging import getLogger\nfrom typing import Tuple\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models import register_model\nfrom timm.models import vision_transformer as tvit\nfrom timm.models import convnext as tconv\n\nfrom einops import rearrange\n\nfrom . import extra_timm_models as et\n\n\nclass Fuser(nn.Module):\n    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):\n        super().__init__()\n        self.gated = gated\n\n        mid_dim = max(src_dim, tgt_dim) * 2\n\n        self.fwd = nn.Sequential(\n            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),\n            nn.GELU(),\n            nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),\n        )\n\n    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:\n        if src.ndim == 3:\n            shape = tgt.shape[-2:]\n        else:\n            shape = src.shape[-2:]\n\n        nd = shape[0] * shape[1]\n\n        if src.ndim == 3:\n            src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)\n\n        if tgt.ndim == 3:\n            tgt_pre = tgt[:, :-nd]\n            tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)\n        else:\n            tgt_pre = None\n\n        pred = self.fwd(src)\n\n        if self.gated:\n            g, pred = torch.chunk(pred, 2, dim=1)\n\n            g = F.sigmoid(g)\n\n            pred = g * pred\n\n        tgt = tgt + pred\n\n        if tgt_pre is not None:\n            tgt = rearrange(tgt, 'b c h w -> b (h w) c')\n            tgt = torch.cat([tgt_pre, tgt], dim=1)\n\n        return tgt\n\n\nclass AttnDownsample(nn.Module):\n    def __init__(self, dim: int, window_size: int, num_heads: int = 16):\n        super().__init__()\n        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)\n        self.kv = nn.Linear(dim, dim * 2)\n        self.proj = 
nn.Linear(dim, dim)\n        self.window_size = window_size\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n\n    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:\n        ntok = twod_shape[0] * twod_shape[1]\n        x_pre = x[:, :-ntok]\n\n        B = x.shape[0]\n        ds_hw = tuple(s // self.window_size for s in twod_shape)\n\n        x_spat = rearrange(\n            x[:, -ntok:],\n            'b (h d1 w d2) c -> (b h w) (d1 d2) c',\n            h=ds_hw[0], w=ds_hw[1],\n            d1=self.window_size, d2=self.window_size,\n        )\n\n        B, N, C = x_spat.shape\n\n        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)\n\n        q = (self.q * self.scale).expand(B, -1, -1, -1)\n        attn = q @ k.transpose(-2, -1)\n        attn = F.softmax(attn, dim=-1)\n        x = attn @ v\n\n        x = x.transpose(1, 2).reshape(B, C)\n        x = self.proj(x)\n\n        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])\n\n        x = torch.cat([x_pre, x], dim=1)\n        return x\n\n\nclass HybridModel(nn.Module):\n    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,\n                 concatenate: bool = False, **kwargs):\n        super().__init__()\n        self.conv = conv\n        self.vit = vit\n        self.concatenate = concatenate\n\n        conv.stages = nn.ModuleList(conv.stages)\n        vit.blocks = nn.ModuleList(vit.blocks)\n\n        self._half_vit_idx = len(vit.blocks) // 2 + 1\n\n        self._half_conv_idx = None\n        x = torch.empty(1, 3, 256, 256)\n        x = self.conv.stem(x)\n        for i in range(len(conv.stages)):\n            x = conv.stages[i](x)\n            if self._half_conv_idx is None and x.shape[-2:] == (16, 16):\n                self._half_conv_idx = i + 1\n                half_conv_dim = 
x.shape[1]\n            final_conv_dim = x.shape[1]\n\n        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)\n        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)\n        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)\n\n        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)\n        if not concatenate:\n            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)\n        self.final_block = tvit.Block(embed_dim, num_heads=16)\n\n        self.embed_dim = embed_dim\n\n    @property\n    def patch_size(self):\n        return 32\n\n    @property\n    def no_fsdp_wrap_types(self):\n        return {tvit.VisionTransformer, tconv.ConvNeXt}\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.forward_features(x)\n\n    def forward_features(self, x: torch.Tensor) -> torch.Tensor:\n        y_vit = self.vit.patch_generator(x)\n\n        for i in range(self._half_vit_idx):\n            y_vit = self.vit.blocks[i](y_vit)\n\n        y_conv = self.conv.stem(x)\n        for i in range(self._half_conv_idx):\n            y_conv = self.conv.stages[i](y_conv)\n\n        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)\n\n        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])\n\n        for i in range(self._half_vit_idx, len(self.vit.blocks)):\n            y_vit = self.vit.blocks[i](y_vit)\n\n        for i in range(self._half_conv_idx, len(self.conv.stages)):\n            y_conv = self.conv.stages[i](y_conv)\n\n        if self.concatenate:\n            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')\n            # Average pool across the board, and replicate for each cls/register token\n            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)\n            y_conv = torch.cat([conv_summary, y_conv], dim=1)\n            y = torch.cat([y_vit, y_conv], dim=2)\n        else:\n 
           y = self.final_fuse(y_conv, y_vit)\n        y = self.final_block(y)\n\n        summary = y[:, :self.vit.patch_generator.num_cls_tokens]\n        features = y[:, self.vit.patch_generator.num_cls_patches:]\n\n        return summary, features\n\n\n@register_model\ndef hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):\n    cfg = dict(num_classes=0, **kwargs)\n    conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)\n    vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)\n\n    return HybridModel(vit, conv, pretrained, concatenate=concatenate)\n\n\n@register_model\ndef hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):\n    cfg = dict(num_classes=0, **kwargs)\n    conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)\n    vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)\n\n    return HybridModel(vit, conv, pretrained, concatenate=concatenate)\n\n\n@register_model\ndef hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):\n    cfg = dict(num_classes=0, **kwargs)\n    conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)\n    vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)\n\n    return HybridModel(vit, conv, pretrained, concatenate=concatenate)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_cpe_support.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom typing import List, Optional, Set, Tuple, Union\nfrom types import MethodType\n\nimport torch\nfrom torch import nn\n\nfrom timm.models import VisionTransformer, checkpoint_seq\nfrom timm.models.vision_transformer import Attention, Block\n\nfrom .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer\n\nfrom .extra_models import DinoWrapper\nfrom .vit_patch_generator import ViTPatchGenerator\nfrom .forward_intermediates import forward_intermediates\nfrom .dual_hybrid_vit import HybridModel\nfrom flash_attn import flash_attn_varlen_func\n\n\ndef _attn_forward_pack(self: Attention, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:\n    N, C = x.shape\n    qkv = self.qkv(x).reshape(N, 3, self.num_heads, self.head_dim).permute(1, 0, 2, 3)\n    q, k, v = qkv.unbind(0)\n    q, k = self.q_norm(q), self.k_norm(k)\n    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()\n\n    x = flash_attn_varlen_func(\n        q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen\n    ).reshape(N, -1)\n\n    x = self.proj(x)\n    x = self.proj_drop(x)\n    return x\n\ndef _block_forward_pack(self: Block, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:\n    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), cu_seqlens)))\n    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))\n    return x\n\ndef _forward_cpe_pack(self: VisionTransformer, images: List[torch.Tensor]) -> torch.Tensor:\n    device = images[0].device\n    x = []\n    seqlens = []\n    
for image in images:\n        # image: [1, c, H, W] -> x: [n_cls+h*w, D], h=H/p and w=W/p\n        _image = self.patch_generator(image).squeeze(0)\n        x.append(_image)\n        seqlens.append(_image.shape[0])\n    \n    x = torch.cat(x, dim=0)\n    seqlens = torch.tensor(seqlens, device=device, dtype=torch.int)\n    \n    cu_seqlens = torch.cat([\n        torch.tensor([0], device=device, dtype=torch.int32), \n        torch.cumsum(seqlens, dim=0, dtype=torch.int32)\n    ])\n    if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():\n        for block in self.blocks:\n            x = checkpoint_seq(block, x, cu_seqlens)\n    else:\n        for block in self.blocks:\n            x = block(x, cu_seqlens)\n    x = self.norm(x)\n    return x, cu_seqlens\n\ndef _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:\n    x = self.patch_generator(x)\n    if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():\n        x = checkpoint_seq(self.blocks, x)\n    else:\n        x = self.blocks(x)\n    x = self.norm(x)\n    return x\n\n\ndef _take_indices(\n        num_blocks: int,\n        n: Optional[Union[int, List[int], Tuple[int]]],\n) -> Tuple[Set[int], int]:\n    if isinstance(n, int):\n        assert n >= 0\n        take_indices = {x for x in range(num_blocks - n, num_blocks)}\n    else:\n        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}\n    return take_indices, max(take_indices)\n\n\ndef _forward_intermediates_cpe(\n        self,\n        x: torch.Tensor,\n        norm: bool = False,\n        **kwargs,\n) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:\n    return forward_intermediates(\n        self,\n        patch_extractor=self.patch_generator,\n        num_summary_tokens=self.patch_generator.num_skip,\n        num_cls_tokens=self.patch_generator.num_cls_tokens,\n        norm=self.norm if norm else lambda y: y,\n        x=x,\n        **kwargs,\n    
)\n\n\ndef _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor:\n    y = _forward_cpe(self.inner, x)\n\n    return y[:, 0], y[:, self.num_summary_tokens:]\n\n\ndef _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):\n    return _forward_intermediates_cpe(self.inner, *args, **kwargs)\n\n\ndef _enable_cpe_for_timm_vit(model: VisionTransformer,\n                             max_img_size: Union[int, Tuple[int, int]] = 1024,\n                             num_cls_tokens: int = 1,\n                             pos_dropout: float = 0.1,\n                             register_multiple: int = Optional[None],\n                             num_registers: int = Optional[None],\n                             support_packing: bool = False,\n):\n    if not isinstance(model, VisionTransformer):\n        raise ValueError(\"CPE only support for VisionTransformer models!\")\n\n    patch_size = model.patch_embed.patch_size[0]\n    embed_dim = model.embed_dim\n    input_dims = model.patch_embed.img_size\n    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)\n    cls_token = model.cls_token is not None\n\n    max_img_size = int(round(max_img_size / patch_size) * patch_size)\n\n    patch_generator = ViTPatchGenerator(\n        patch_size=patch_size,\n        embed_dim=embed_dim,\n        input_dims=input_dims,\n        normalize_patches=normalize_patches,\n        cls_token=cls_token,\n        max_input_dims=max_img_size,\n        pos_dropout=pos_dropout,\n        num_cls_tokens=num_cls_tokens,\n        register_multiple=register_multiple,\n        num_registers=num_registers,\n    )\n\n    model.patch_generator = patch_generator\n    model.patch_embed = None\n    model.cls_token = None\n    model.pos_embed = None\n    model.pos_drop = None\n    model.patch_size = patch_size\n    model.num_cls_tokens = num_cls_tokens\n    model.num_registers = patch_generator.num_registers\n\n    model.forward_features = MethodType(_forward_cpe, 
model)\n    model.forward_intermediates = MethodType(_forward_intermediates_cpe, model)\n    if support_packing:\n        model.forward_features = MethodType(_forward_cpe_pack, model)\n        for block in model.blocks:\n            block.forward = MethodType(_block_forward_pack, block)\n            block.attn.forward = MethodType(_attn_forward_pack, block.attn)\n\n\ndef _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,\n                                max_img_size: Union[int, Tuple[int, int]] = 1024,\n                                num_cls_tokens: int = 1,\n                                pos_dropout: float = 0.1,\n                                register_multiple: int = Optional[None],\n                                num_registers: int = Optional[None],\n):\n    patch_size = model.patch_size\n    embed_dim = model.embed_dim\n    input_dims = model.inner.patch_embed.patches_resolution\n    normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity)\n    cls_token = True\n\n    max_img_size = int(round(max_img_size / patch_size) * patch_size)\n\n    patch_generator = ViTPatchGenerator(\n        patch_size=patch_size,\n        embed_dim=embed_dim,\n        input_dims=input_dims,\n        normalize_patches=normalize_patches,\n        cls_token=cls_token,\n        max_input_dims=max_img_size,\n        pos_dropout=pos_dropout,\n        num_cls_tokens=num_cls_tokens,\n        register_multiple=register_multiple,\n        num_registers=num_registers,\n        patch_bias=True,\n    )\n\n    inner = model.inner\n    inner.patch_generator = patch_generator\n    inner.patch_embed = None\n    inner.cls_token = None\n    inner.pos_embed = None\n    inner.register_tokens = None\n    inner.patch_size = patch_size\n\n    model.forward_features = MethodType(_forward_cpe_dinov2, model)\n    model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model)\n\n\ndef enable_cpe(model: nn.Module,\n               *args,\n               **kwargs,\n):\n  
  if isinstance(model, VisionTransformer):\n        _enable_cpe_for_timm_vit(model, *args, **kwargs)\n    elif isinstance(model, DinoWrapper):\n        _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)\n    elif isinstance(model, HybridModel):\n        _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)\n    else:\n        raise ValueError(f'CPE not supported for this model type: {type(model)}')\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_damp.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom logging import getLogger\nimport math\nimport os\nfrom typing import Dict, List, Optional, Union, Tuple\nfrom types import MethodType\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.utils import parametrize\n\n\n# For now, don't do anything\nclass DAMP(nn.Identity):\n    def __init__(self, std: float):\n        super().__init__()\n        self.std = std\n\n\ndef enable_damp(model: nn.Module, std: float):\n    if isinstance(model, (list, tuple)):\n        for m in model:\n            enable_damp(m, std)\n        return\n\n    for name, module in model.named_modules():\n        if isinstance(module, nn.Linear):\n            parametrize.register_parametrization(module, 'weight', DAMP(std))\n\n\ndef configure_damp_from_args(model: nn.Module, args):\n    damp = getattr(args, 'damp', None)\n    if damp:\n        enable_damp(model, damp)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/enable_spectral_reparam.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom logging import getLogger\nimport math\nimport os\nfrom typing import Dict, List, Optional, Union, Tuple\nfrom types import MethodType\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.nn.utils import parametrize\nfrom torch.nn.utils.parametrizations import _SpectralNorm\n\nfrom timm.models.vision_transformer import Attention, Mlp\n\n_EPS = 1e-5\n\n\nclass _SNReweight(_SpectralNorm):\n    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):\n        super().__init__(weight, *args, **kwargs)\n\n        self.alpha = alpha\n        self.version = version\n        self.register_buffer('_sn_version', torch.tensor(version))\n\n        if init_norm_to_current:\n            # This will set the numerator to match the denominator, which should preserve the original values\n            init_scale = self._get_sigma(weight, n_power_iterations=20).item()\n        else:\n            init_scale = 1.0\n\n        if version == 1:\n            init_value = init_scale\n        elif version == 2:\n            t = init_scale - alpha\n            if t < _EPS:\n                getLogger(\"spectral_reparam\").warn(f'The initialized spectral norm {init_scale} is too small to be represented. 
Setting to {_EPS} instead.')\n                t = _EPS\n\n            init_value = math.log(math.exp(t) - 1)\n        else:\n            raise ValueError(f'Unsupported version: {version}')\n\n        # Make 2D so that weight decay gets applied\n        self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))\n\n    # Re-implementing this because we need to make division by sigma safe\n    def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = None) -> torch.Tensor:\n        if not n_power_iterations:\n            n_power_iterations = self.n_power_iterations\n        if weight.ndim == 1:\n            # Faster and more exact path, no need to approximate anything\n            sigma = weight.norm()\n        else:\n            weight_mat = self._reshape_weight_to_matrix(weight)\n            if self.training:\n                self._power_method(weight_mat, n_power_iterations)\n            # See above on why we need to clone\n            u = self._u.clone(memory_format=torch.contiguous_format)\n            v = self._v.clone(memory_format=torch.contiguous_format)\n            # The proper way of computing this should be through F.bilinear, but\n            # it seems to have some efficiency issues:\n            # https://github.com/pytorch/pytorch/issues/58093\n            sigma = torch.dot(u, torch.mv(weight_mat, v))\n\n        return sigma + self.eps\n\n    def forward(self, weight: torch.Tensor, *args, **kwargs):\n        dtype = weight.dtype\n        sigma = self._get_sigma(weight, *args, **kwargs)\n\n        if self.version == 1:\n            scale = self.scale\n        elif self.version == 2:\n            scale = F.softplus(self.scale) + self.alpha\n        else:\n            raise ValueError(f'Unsupported version: {self.version}')\n\n        scale = scale.float() / sigma.float()\n\n        y = weight * scale\n\n        if dtype in (torch.float16, torch.bfloat16):\n            y = y.to(dtype)\n        return 
y\n\n    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):\n        version_key = f'{prefix}_sn_version'\n        if version_key not in state_dict:\n            self.version = 1\n            state_dict[version_key] = torch.tensor(1)\n        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)\n\n\nclass _ChunkedSNReweight(nn.Module):\n    def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):\n        super().__init__()\n\n        self.num_chunks = num_chunks\n        parts = weight.split(weight.shape[0] // num_chunks, dim=0)\n\n        self.parts = nn.ModuleList([\n            _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs)\n            for p in parts\n        ])\n\n    def forward(self, weight: torch.Tensor, *args, **kwargs):\n        parts = weight.split(weight.shape[0] // self.num_chunks, dim=0)\n\n        parts = [\n            fn(p)\n            for fn, p in zip(self.parts, parts)\n        ]\n\n        return torch.cat(parts, dim=0)\n\n\nclass _AttnSNReweight(_ChunkedSNReweight):\n    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):\n        super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)\n\n        if not renorm_values:\n            self.parts[2] = nn.Identity()\n\n\ndef enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],\n                            n_power_iterations: int = 1,\n                            eps: float = 1e-6,\n                            init_norm_to_current: bool = False,\n                            renorm_values: bool = True,\n                            renorm_mlp: bool = True,\n                            state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):\n    if 
isinstance(model, (list, tuple)):\n        for i, sub in enumerate(model):\n            sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance\n            enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,\n                                    init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,\n                                    renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)\n        return\n\n    print('Enabling spectral reparametrization')\n    args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)\n    visited_prefixes = set()\n\n    def is_guidance_parametrized(name: str):\n        if state_dict_guidance is None:\n            return True\n\n        p_name = f'{name}.parametrizations'\n        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))\n        return is_prm\n\n    def parametrize_linear(linear: nn.Linear):\n        parametrize.register_parametrization(\n            linear,\n            'weight',\n            _SNReweight(linear.weight, **args)\n        )\n\n    for name, mod in model.named_modules():\n        pref = '.'.join(name.split('.')[:-1])\n        if pref in visited_prefixes:\n            continue\n\n        if isinstance(mod, Attention) or name.endswith('.attn'):\n            if is_guidance_parametrized(f'{name}.qkv'):\n                parametrize.register_parametrization(\n                    mod.qkv,\n                    'weight',\n                    _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),\n                )\n            if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):\n                parametrize_linear(mod.proj)\n            visited_prefixes.add(name)\n        elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):\n            if is_guidance_parametrized(f'{name}.w12'):\n      
          parametrize.register_parametrization(\n                    mod.w12,\n                    'weight',\n                    _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),\n                )\n            if is_guidance_parametrized(f'{name}.w3'):\n                parametrize_linear(mod.w3)\n            visited_prefixes.add(name)\n        elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):\n            parametrize_linear(mod)\n\n\ndef configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):\n    spectral_reparam = getattr(args, 'spectral_reparam', False)\n    if isinstance(spectral_reparam, bool) and spectral_reparam:\n        enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)\n    elif isinstance(spectral_reparam, dict):\n        enable_spectral_reparam(\n            model,\n            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),\n            eps=spectral_reparam.get('eps', 1e-12),\n            init_norm_to_current=True,\n            state_dict_guidance=state_dict_guidance,\n        )\n\n\ndef disable_spectral_reparam(model: nn.Module):\n    print('Disabling spectral reparametrization')\n    for name, mod in model.named_modules():\n        if parametrize.is_parametrized(mod):\n            parametrize.remove_parametrizations(mod, 'weight')\n            pass\n\n\n\nif __name__ == '__main__':\n    import argparse\n    from . 
import radio_model as create_model\n\n    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')\n    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')\n    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')\n    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')\n    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')\n\n    args = parser.parse_args()\n\n    if not args.output:\n        chk_dir, chk_name = os.path.split(args.checkpoint)\n        args.output = os.path.join(chk_dir, f'clean_{chk_name}')\n        print(f'Set output to \"{args.output}\"')\n\n    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)\n\n    model = create_model.create_model_from_args(chk['args'])\n\n    key = 'base_model.'\n    mod_state = dict()\n    extra_state = dict()\n    for k, v in chk['state_dict'].items():\n        if k.startswith(key):\n            mod_state[k[len(key):]] = v\n        else:\n            extra_state[k] = v\n\n    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)\n    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:\n        print(chk_load_info)\n\n    if chk['args'].spectral_reparam:\n        disable_spectral_reparam(model)\n\n    if hasattr(chk['args'], 'dtype'):\n        model.to(dtype=chk['args'].dtype)\n\n    mod_state = model.state_dict()\n    final_state = dict()\n    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})\n    final_state.update(extra_state)\n\n    chk['state_dict'] = final_state\n    chk['args'].spectral_reparam = False\n\n    if args.release:\n        chk = {\n            'arch': chk['arch'],\n            'epoch': chk['epoch'],\n            'state_dict': chk['state_dict'],\n            'args': chk['args'],\n        }\n\n    torch.save(chk, args.output)\n   
 pass\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/eradio_model.py",
    "content": "#!/usr/bin/env python3\n\n# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n# E-RADIO model from\n# Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. \"AM-RADIO: Agglomerative Model--Reduce All Domains Into One.\" arXiv preprint arXiv:2312.06709 (2023).\n\n# based on FasterViT, Swin Transformer, YOLOv8\n\n# FasterViT:\n# Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. \"FasterViT: Fast Vision Transformers with Hierarchical Attention.\" arXiv preprint arXiv:2306.06189 (2023).\n\nimport timm\nimport torch\nimport torch.nn as nn\nfrom timm.models.registry import register_model\n\nfrom timm.models.layers import trunc_normal_, DropPath, LayerNorm2d\nimport numpy as np\nimport torch.nn.functional as F\nimport math\nimport warnings\n\n#######################\n## Codebase from YOLOv8\n## BEGINNING\n#######################\n\nclass C2f(nn.Module):\n    \"\"\"Faster Implementation of CSP Bottleneck with 2 convolutions.\"\"\"\n    \"\"\"From YOLOv8 codebase\"\"\"\n    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):  # ch_in, ch_out, number, shortcut, groups, expansion\n        super().__init__()\n        if drop_path is None:\n            drop_path = [0.0] * n\n\n        self.c = int(c2 * e)  # hidden channels\n        self.cv1 = Conv(c1, 2 * self.c, 1, 1)\n        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)\n        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))\n\n    def 
forward(self, x):\n        \"\"\"Forward pass through C2f layer.\"\"\"\n        y = list(self.cv1(x).chunk(2, 1))\n        y.extend(m(y[-1]) for m in self.m)\n        return self.cv2(torch.cat(y, 1))\n\n    def forward_split(self, x):\n        \"\"\"Forward pass using split() instead of chunk().\"\"\"\n        y = list(self.cv1(x).split((self.c, self.c), 1))\n        y.extend(m(y[-1]) for m in self.m)\n        return self.cv2(torch.cat(y, 1))\n\nclass Bottleneck(nn.Module):\n    \"\"\"Standard bottleneck.\"\"\"\n\n    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):  # ch_in, ch_out, shortcut, groups, kernels, expand\n        super().__init__()\n        c_ = int(c2 * e)  # hidden channels\n        self.cv1 = Conv(c1, c_, k[0], 1)\n        self.cv2 = Conv(c_, c2, k[1], 1, g=g)\n        self.add = shortcut and c1 == c2\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n    def forward(self, x):\n        \"\"\"'forward()' applies the YOLOv5 FPN to input data.\"\"\"\n        return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))\n\n\nclass Conv(nn.Module):\n    \"\"\"Modified to support layer fusion\"\"\"\n    default_act = nn.SiLU()  # default activation\n\n    def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):\n        super().__init__()\n\n        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)\n        if 1:\n            self.bn = torch.nn.BatchNorm2d(b)\n            torch.nn.init.constant_(self.bn.weight, bn_weight_init)\n            torch.nn.init.constant_(self.bn.bias, 0)\n        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()\n\n\n    def forward(self,x):\n        x = self.conv(x)\n        x = self.bn(x)\n        x = self.act(x)\n        return x\n\n    
@torch.no_grad()\n    def switch_to_deploy(self):\n        # return 1\n        if not isinstance(self.bn, nn.Identity):\n            c, bn = self.conv, self.bn\n            w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n            w = c.weight * w[:, None, None, None]\n            b = bn.bias - bn.running_mean * bn.weight / \\\n                (bn.running_var + bn.eps)**0.5\n\n            self.conv.weight.data.copy_(w)\n            self.conv.bias = nn.Parameter(b)\n\n            self.bn = nn.Identity()\n\ndef autopad(k, p=None, d=1):  # kernel, padding, dilation\n    \"\"\"Pad to 'same' shape outputs.\"\"\"\n    if d > 1:\n        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size\n    if p is None:\n        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad\n    return p\n\n\n#######################\n## Codebase from YOLOv8\n## END\n#######################\n\ndef pixel_unshuffle(data, factor=2):\n    # performs nn.PixelShuffle(factor) in reverse, torch has some bug for ONNX and TRT, so doing it manually\n    B, C, H, W = data.shape\n    return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)\n\nclass SwiGLU(nn.Module):\n    # should be more advanced, but doesnt improve results so far\n    def forward(self, x):\n        x, gate = x.chunk(2, dim=-1)\n        return F.silu(gate) * x\n\n\ndef window_partition(x, window_size):\n    \"\"\"\n    Function for partitioning image into windows and later do windowed attention\n    Args:\n        x: (B, C, H, W)\n        window_size: window size\n    Returns:\n        windows - local window features (num_windows*B, window_size*window_size, C)\n        (Hp, Wp) -  the size of the padded image\n    \"\"\"\n    B, C, H, W = x.shape\n\n    if window_size == 0 or (window_size==H and window_size==W):\n        windows = x.flatten(2).transpose(1, 2)\n        Hp, Wp = H, W\n    else:\n        pad_h = 
(window_size - H % window_size) % window_size\n        pad_w = (window_size - W % window_size) % window_size\n        if pad_h > 0 or pad_w > 0:\n            x = F.pad(x, (0, pad_w, 0, pad_h), mode=\"reflect\")\n        Hp, Wp = H + pad_h, W + pad_w\n\n        x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)\n        windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)\n\n    return windows, (Hp, Wp)\n\nclass Conv2d_BN(nn.Module):\n    '''\n    Conv2d + BN layer with folding capability to speed up inference\n    Can be merged with Conv() function with additional arguments\n    '''\n    def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)\n        if 1:\n            self.bn = torch.nn.BatchNorm2d(b)\n            torch.nn.init.constant_(self.bn.weight, bn_weight_init)\n            torch.nn.init.constant_(self.bn.bias, 0)\n\n    def forward(self,x):\n        x = self.conv(x)\n        x = self.bn(x)\n        return x\n\n    @torch.no_grad()\n    def switch_to_deploy(self):\n        if not isinstance(self.bn, nn.Identity):\n            c, bn = self.conv, self.bn\n            w = bn.weight / (bn.running_var + bn.eps) ** 0.5\n            w = c.weight * w[:, None, None, None]\n            b = bn.bias - bn.running_mean * bn.weight / \\\n                (bn.running_var + bn.eps)**0.5\n            self.conv.weight.data.copy_(w)\n            self.conv.bias = nn.Parameter(b)\n            self.bn = nn.Identity()\n\n\n\ndef window_reverse(windows, window_size, H, W, pad_hw):\n    \"\"\"\n    Windows to the full feature map\n    Args:\n        windows: local window features (num_windows*B, window_size, window_size, C)\n        window_size: Window size\n        H: Height of image\n        W: Width of image\n        pad_w - a tuple of 
image passing used in windowing step\n    Returns:\n        x: (B, C, H, W)\n\n    \"\"\"\n    # print(f\"window_reverse, windows.shape {windows.shape}\")\n    Hp, Wp = pad_hw\n    if window_size == 0 or (window_size==H and window_size==W):\n        B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))\n        x = windows.transpose(1, 2).view(B, -1, H, W)\n    else:\n        B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))\n        x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)\n        x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)\n\n        if Hp > H or Wp > W:\n            x = x[:, :, :H, :W, ].contiguous()\n\n    return x\n\n\n\nclass PosEmbMLPSwinv2D(nn.Module):\n    \"\"\"\n    2D positional embedding from Swin Transformer v2\n    Added functionality to store the positional embedding in the model and not recompute it every time\n    \"\"\"\n    def __init__(\n        self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,\n    ):\n        super().__init__()\n        self.window_size = window_size\n        self.num_heads = num_heads\n        # mlp to generate continuous relative position bias\n        self.cpb_mlp = nn.Sequential(\n            nn.Linear(2, cpb_mlp_hidden, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Linear(cpb_mlp_hidden, num_heads, bias=False),\n        )\n\n        self.grid_exists = False\n        self.seq_length = seq_length\n        self.deploy = False\n        self.num_heads = num_heads\n        self.no_log = no_log\n        self.pretrained_window_size = pretrained_window_size\n        self.relative_bias_window_size = window_size\n\n        relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,\n                                                                                                     pretrained_window_size, 
seq_length,\n                                                                                                     no_log)\n\n        self.register_buffer(\"relative_coords_table\", relative_coords_table)\n        self.register_buffer(\"relative_position_index\", relative_position_index)\n        self.register_buffer(\"relative_bias\", relative_bias)  # for EMA\n\n    def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):\n        # as in separate function to support window size chage after model weights loading\n        relative_coords_h = torch.arange(\n            -(window_size[0] - 1), window_size[0], dtype=torch.float32\n        )\n        relative_coords_w = torch.arange(\n            -(window_size[1] - 1), window_size[1], dtype=torch.float32\n        )\n        relative_coords_table = (\n            torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))\n            .permute(1, 2, 0)\n            .contiguous()\n            .unsqueeze(0)\n        )  # 1, 2*Wh-1, 2*Ww-1, 2\n        if pretrained_window_size[0] > 0:\n            relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1\n            relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1\n        else:\n            relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1\n            relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1\n\n        if not no_log:\n            relative_coords_table *= 8  # normalize to -8, 8\n            relative_coords_table = (\n                torch.sign(relative_coords_table)\n                * torch.log2(torch.abs(relative_coords_table) + 1.0)\n                / np.log2(8)\n            )\n\n        # get pair-wise relative position index for each token inside the window\n        coords_h = torch.arange(self.window_size[0])\n        coords_w = torch.arange(self.window_size[1])\n        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww\n        
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww\n        relative_coords = (\n            coords_flatten[:, :, None] - coords_flatten[:, None, :]\n        )  # 2, Wh*Ww, Wh*Ww\n        relative_coords = relative_coords.permute(\n            1, 2, 0\n        ).contiguous()  # Wh*Ww, Wh*Ww, 2\n        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0\n        relative_coords[:, :, 1] += self.window_size[1] - 1\n        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww\n\n        relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)\n\n        self.relative_bias_window_size = window_size\n\n        return relative_coords_table, relative_position_index, relative_bias\n\n\n    def switch_to_deploy(self):\n        self.deploy = True\n        self.grid_exists = True\n\n    def forward(self, input_tensor):\n        # for efficiency, we want this forward to be folded into a single operation (sum)\n        # if resolution stays the same, then we dont need to recompute MLP layers\n\n        if not self.deploy or self.training:\n            self.grid_exists = False\n\n        #compare if all elements in self.window_size list match those in self.relative_bias_window_size\n        if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):\n            relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,\n                                                                                                        self.pretrained_window_size, self.seq_length,\n                                                                                                        self.no_log)\n\n            self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)\n            self.relative_position_index = 
relative_position_index.to(self.relative_position_index.device)\n            self.relative_bias = relative_bias.to(self.relative_bias.device)\n\n        if self.deploy and self.grid_exists:\n            input_tensor = input_tensor + self.relative_bias\n            return input_tensor\n\n        if 1:\n            self.grid_exists = True\n\n            relative_position_bias_table = self.cpb_mlp(\n                self.relative_coords_table\n            ).view(-1, self.num_heads)\n            relative_position_bias = relative_position_bias_table[\n                self.relative_position_index.view(-1)\n            ].view(\n                self.window_size[0] * self.window_size[1],\n                self.window_size[0] * self.window_size[1],\n                -1,\n            )  # Wh*Ww,Wh*Ww,nH\n\n            relative_position_bias = relative_position_bias.permute(\n                2, 0, 1\n            ).contiguous()  # nH, Wh*Ww, Wh*Ww\n            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)\n\n            self.relative_bias = relative_position_bias.unsqueeze(0)\n\n        input_tensor = input_tensor + self.relative_bias\n        return input_tensor\n\n\nclass GRAAttentionBlock(nn.Module):\n    def __init__(self, window_size, dim_in, dim_out,\n                 num_heads, drop_path=0., qk_scale=None, qkv_bias=False,\n                 norm_layer=nn.LayerNorm, layer_scale=None,\n                  use_swiglu=True,\n                  subsample_ratio=1, dim_ratio=1, conv_base=False,\n                  do_windowing=True, multi_query=False, use_shift=0,\n                  cpb_mlp_hidden=512, conv_groups_ratio=0):\n        '''\n        Global Resolution Attention Block , see README for details\n        Attention with subsampling to get a bigger receptive field for attention\n        conv_base - use conv2d instead of avgpool2d for downsample / upsample\n\n\n        '''\n        super().__init__()\n\n        self.shift_size=window_size//2 if use_shift else 
0\n\n        self.do_windowing = do_windowing\n        self.subsample_ratio = subsample_ratio\n\n\n\n        if do_windowing:\n            if conv_base:\n                    self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()\n\n\n                    self.downsample_mixer = nn.Identity()\n                    self.upsample_mixer = nn.Identity()\n                    self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()\n            else:\n                self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()\n                self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()\n                self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()\n                self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()\n\n\n        # in case there is no downsampling conv we want to have it separately\n        # will help with information propagation between windows\n        if subsample_ratio == 1:\n            # conv_groups_ratio=0\n            self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)\n            # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)\n            # self.pre_conv_act = nn.ReLU6()\n            #for simplicity:\n            self.pre_conv_act = nn.Identity()\n            if conv_groups_ratio == -1:\n                self.pre_conv = nn.Identity()\n                self.pre_conv_act = nn.Identity()\n\n        
self.window_size = window_size\n\n        self.norm1 = norm_layer(dim_in)\n\n        self.attn = WindowAttention(\n            dim_in,\n            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            resolution=window_size,\n            seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,\n            shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden)\n\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n        use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]\n        self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in))  if use_layer_scale else 1\n\n        ### mlp layer\n        mlp_ratio = 4\n        self.norm2 = norm_layer(dim_in)\n        mlp_hidden_dim = int(dim_in * mlp_ratio)\n\n        activation = nn.GELU if not use_swiglu else SwiGLU\n        mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim\n\n        self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)\n\n        self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1\n        self.drop_path2=DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n\n\n    def forward(self, x):\n        skip_connection = x\n        attn_mask = None\n\n        # in case there is no downsampling conv we want to have it separately\n        # will help with information propagation\n        if self.subsample_ratio == 1:\n            x = self.pre_conv_act(self.pre_conv(x)) + skip_connection\n\n        if self.do_windowing:\n            # performing windowing if required\n            x = self.downsample_op(x)\n            x = self.downsample_mixer(x)\n\n            if self.window_size>0:\n                H, W = x.shape[2], x.shape[3]\n\n            if self.shift_size > 0 and H>self.window_size and W>self.window_size:\n                # @swin like cyclic shift, doesnt show better performance\n                x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))\n\n            x, pad_hw = window_partition(x, self.window_size)\n\n            if self.shift_size > 0 and H>self.window_size and W>self.window_size:\n                # set atten matrix to have -100 and the top right square\n                # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0\n                # calculate attention mask for SW-MSA\n                # not used in final version, can be useful for some cases especially for high res\n                H, W = pad_hw\n                img_mask = torch.zeros((1, H, W, 1), device=x.device)  # 1 H W 1\n                h_slices = (slice(0, -self.window_size),\n                            slice(-self.window_size, -self.shift_size),\n                            slice(-self.shift_size, None))\n                w_slices = (slice(0, -self.window_size),\n                            slice(-self.window_size, -self.shift_size),\n                            slice(-self.shift_size, None))\n                cnt = 0\n                for h in h_slices:\n                    for w in w_slices:\n                        img_mask[:, h, w, :] = cnt\n                        cnt += 1\n                
img_mask = img_mask.transpose(1,2).transpose(1,3)\n                mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1\n\n                mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size)\n                attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n                attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n\n        # window attention\n        x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask)) # or pass H,W\n        # mlp layer\n        x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))\n\n        if self.do_windowing:\n            if self.window_size > 0:\n                x = window_reverse(x, self.window_size, H, W, pad_hw)\n\n            # reverse cyclic shift\n            if self.shift_size > 0 and H>self.window_size and W>self.window_size:\n                # @swin like cyclic shift, not tested\n                x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))\n\n            x = self.upsample_mixer(x)\n            x = self.upsample_op(x)\n\n\n            if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:\n                x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode=\"reflect\")\n        # need to add skip connection because downsampling and upsampling will break residual connection\n        # 0.5 is needed to make sure that the skip connection is not too strong\n        # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection\n        x = 0.5 * x + 0.5 * skip_connection\n        return x\n\n\n\n\nclass MultiResolutionAttention(nn.Module):\n    \"\"\"\n    MultiResolutionAttention (MRA) module\n    The idea is to use multiple attention blocks with different resolution\n    Feature maps are 
downsampled / upsampled for each attention block on different blocks\n    Every attention block supports windowing\n    \"\"\"\n\n    def __init__(self, window_size, sr_ratio,\n                 dim, dim_ratio, num_heads,\n                 do_windowing=True,\n                 layer_scale=1e-5, norm_layer=nn.LayerNorm,\n                 drop_path = 0, qkv_bias=False, qk_scale=1.0,\n                 use_swiglu=True, multi_query=False, conv_base=False,\n                 use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None:\n        \"\"\"\n        Args:\n            input_resolution: input image resolution\n            window_size: window size\n            compression_ratio: compression ratio\n            max_depth: maximum depth of the GRA module\n            use_shift: do window shifting\n        \"\"\"\n        super().__init__()\n\n        depth = len(sr_ratio)\n\n        self.attention_blocks = nn.ModuleList()\n\n\n        for i in range(depth):\n            subsample_ratio = sr_ratio[i]\n            if len(window_size) > i:\n                window_size_local = window_size[i]\n            else:\n                window_size_local = window_size[0]\n\n            self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,\n                                            dim_in=dim, dim_out=dim, num_heads=num_heads,\n                                            qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,\n                                            layer_scale=layer_scale, drop_path=drop_path,\n                                            use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,\n                                            do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base,\n                                            use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio),\n                                        )\n\n    def forward(self, 
x):\n\n        for attention_block in self.attention_blocks:\n            x = attention_block(x)\n\n        return x\n\n\n\nclass Mlp(nn.Module):\n    \"\"\"\n    Multi-Layer Perceptron (MLP) block\n    \"\"\"\n\n    def __init__(self,\n                 in_features,\n                 hidden_features=None,\n                 out_features=None,\n                 act_layer=nn.GELU,\n                 use_swiglu=True,\n                 drop=0.):\n        \"\"\"\n        Args:\n            in_features: input features dimension.\n            hidden_features: hidden features dimension.\n            out_features: output features dimension.\n            act_layer: activation function.\n            drop: dropout rate.\n        \"\"\"\n\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=False)\n\n    def forward(self, x):\n        x_size = x.size()\n        x = x.view(-1, x_size[-1])\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.fc2(x)\n        x = x.view(x_size)\n        return x\n\nclass Downsample(nn.Module):\n    \"\"\"\n    Down-sampling block\n    Pixel Unshuffle is used for down-sampling, works great accuracy - wise but takes 10% more TRT time\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 shuffle = False,\n                 ):\n        \"\"\"\n        Args:\n            dim: feature size dimension.\n            shuffle: idea with\n            keep_dim: bool argument for maintaining the resolution.\n        \"\"\"\n\n        super().__init__()\n        dim_out = 2 * dim\n\n        if shuffle:\n            self.norm = lambda x: pixel_unshuffle(x, factor=2)\n            self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)\n         
   # pixel unshuffleging works well but doesnt provide any speedup\n        else:\n            # removed layer norm for better, in this formulation we are getting 10% better speed\n            # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension\n            # therefore we remove it compared to the original implementation in FasterViT\n            self.norm = nn.Identity()\n            self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)\n\n\n    def forward(self, x):\n        x = self.norm(x)\n        x = self.reduction(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    Patch embedding block\n    Used to convert image into an initial set of feature maps with lower resolution\n    \"\"\"\n\n    def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):\n        \"\"\"\n        Args:\n            in_chans: number of input channels.\n            in_dim: intermediate feature size dimension to speed up stem.\n            dim: final stem channel number\n            shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field\n        \"\"\"\n\n        super().__init__()\n        # shuffle_down = False\n        if not shuffle_down:\n            self.proj = nn.Identity()\n            self.conv_down = nn.Sequential(\n                Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),\n                nn.ReLU(),\n                Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),\n                nn.ReLU()\n                )\n        else:\n            self.proj = lambda x: pixel_unshuffle(x, factor=4)\n            self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),\n                                           nn.ReLU(),\n                                           )\n\n    def forward(self, x):\n        x = self.proj(x)\n        x = self.conv_down(x)\n        return x\n\n\n\nclass ConvBlock(nn.Module):\n    \"\"\"\n    Convolutional block, used in 
first couple of stages\n    Experimented with plan resnet-18 like modules, they are the best in terms of throughput\n    Finally, YOLOv8 idea seem to work fine (resnet-18 like block with squeezed feature dimension, and feature concatendation at the end)\n    \"\"\"\n    def __init__(self, dim,\n                 drop_path=0.,\n                 layer_scale=None,\n                 kernel_size=3,\n                 ):\n        super().__init__()\n\n        self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)\n        self.act1 = nn.GELU()\n\n        self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)\n\n        self.layer_scale = layer_scale\n        if layer_scale is not None and type(layer_scale) in [int, float]:\n            self.gamma = nn.Parameter(layer_scale * torch.ones(dim))\n            self.layer_scale = True\n        else:\n            self.layer_scale = False\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n\n    def forward(self, x):\n        input = x\n\n        x = self.conv1(x)\n        x = self.act1(x)\n        x = self.conv2(x)\n\n        if self.layer_scale:\n            x = x * self.gamma.view(1, -1, 1, 1)\n        x = input + self.drop_path(x)\n        return x\n\n\nclass WindowAttention(nn.Module):\n    # Windowed Attention from SwinV2\n    # use a MLP trick to deal with various input image resolutions, then fold it to improve speed\n\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,\n                 seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512):\n        # taken from EdgeViT and tweaked with attention bias.\n        super().__init__()\n        if not dim_out: dim_out = dim\n        self.shift_size = shift_size\n        self.multi_query = multi_query\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.head_dim = dim // num_heads\n\n        
self.dim_internal = dim\n\n        self.scale = qk_scale or head_dim ** -0.5\n        if not multi_query:\n            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        else:\n            self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)\n\n        self.proj = nn.Linear(dim, dim_out, bias=False)\n        # attention positional bias\n        self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],\n                                              pretrained_window_size=[resolution, resolution],\n                                              num_heads=num_heads,\n                                              seq_length=seq_length,\n                                              cpb_mlp_hidden=cpb_mlp_hidden)\n\n        self.resolution = resolution\n\n    def forward(self, x, attn_mask = None):\n        B, N, C = x.shape\n\n        if not self.multi_query:\n            qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n            q, k, v = qkv[0], qkv[1], qkv[2]\n        else:\n            qkv = self.qkv(x)\n            (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)\n\n            q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n            k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)\n            v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n\n        attn = self.pos_emb_funct(attn)\n\n        #add window shift\n        if attn_mask is not None:\n            nW = attn_mask.shape[0]\n            attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)\n            attn = attn.view(-1, self.num_heads, N, N)\n\n        attn = attn.softmax(dim=-1)\n        x = (attn @ v).transpose(1, 2).reshape(B, -1, C)\n        x = self.proj(x)\n        return x\n\n\n\nclass ERADIOLayer(nn.Module):\n    \"\"\"\n  
  E-RADIO Layer\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 depth,\n                 num_heads,\n                 window_size,\n                 conv=False,\n                 downsample=True,\n                 mlp_ratio=4.,\n                 qkv_bias=False,\n                 qk_scale=None,\n                 norm_layer=nn.LayerNorm,\n                 drop_path=0.,\n                 layer_scale=None,\n                 layer_scale_conv=None,\n                 sr_dim_ratio=1,\n                 sr_ratio=1,\n                 multi_query=False,\n                 use_swiglu=True,\n                 yolo_arch=False,\n                 downsample_shuffle=False,\n                 conv_base=False,\n                 use_shift=False,\n                 cpb_mlp_hidden=512,\n                 conv_groups_ratio=0,\n                 verbose: bool = True,\n\n    ):\n        \"\"\"\n        Args:\n            dim: feature size dimension.\n            depth: number of layers in each stage.\n            input_resolution: input image resolution.\n            window_size: window size in each stage.\n            downsample: bool argument for down-sampling.\n            mlp_ratio: MLP ratio.\n            num_heads: number of heads in each stage.\n            qkv_bias: bool argument for query, key, value learnable bias.\n            qk_scale: bool argument to scaling query, key.\n            drop: dropout rate.\n            attn_drop: attention dropout rate.\n            drop_path: drop path rate.\n            norm_layer: normalization layer.\n            layer_scale: layer scaling coefficient.\n            use_shift: SWIN like window shifting for half the window size for every alternating layer (considering multi-resolution)\n            conv_groups_ratio: group ratio for conv when no subsampling in multi-res attention\n        \"\"\"\n\n        super().__init__()\n        self.conv = conv\n        self.yolo_arch=False\n        self.verbose = verbose\n        
if conv:\n            if not yolo_arch:\n                self.blocks = nn.ModuleList([\n                    ConvBlock(dim=dim,\n                            drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                            layer_scale=layer_scale_conv)\n                    for i in range(depth)])\n                self.blocks = nn.Sequential(*self.blocks)\n            else:\n                self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)\n                self.yolo_arch=True\n        else:\n            if not isinstance(window_size, list): window_size = [window_size]\n            self.window_size = window_size[0]\n            self.do_single_windowing = True\n            if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]\n            self.sr_ratio = sr_ratio\n            if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:\n                self.do_single_windowing = False\n                do_windowing = True\n            else:\n                self.do_single_windowing = True\n                do_windowing = False\n\n            #for v2_2\n            if conv_groups_ratio != -1:\n                self.do_single_windowing = False\n                do_windowing = True\n\n            self.blocks = nn.ModuleList()\n            for i in range(depth):\n                self.blocks.append(\n                    MultiResolutionAttention(window_size=window_size,\n                                             sr_ratio=sr_ratio,\n                                             dim=dim,\n                                             dim_ratio = sr_dim_ratio,\n                                             num_heads=num_heads,\n                                             norm_layer=norm_layer,\n                                             drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n                                             layer_scale=layer_scale,\n                                             
qkv_bias=qkv_bias,\n                                             qk_scale=qk_scale,\n                                             use_swiglu=use_swiglu,\n                                             do_windowing=do_windowing,\n                                             multi_query=multi_query,\n                                             conv_base=conv_base,\n                                             cpb_mlp_hidden=cpb_mlp_hidden,\n                                             use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True    ,\n                                             conv_groups_ratio=conv_groups_ratio,\n                    ))\n            self.blocks = nn.Sequential(*self.blocks)\n\n        self.transformer = not conv\n        self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)\n\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n\n        # do padding for transforemr\n        interpolate = True\n        if self.transformer and interpolate:\n            # Windowed Attention will split feature map into windows with the size of window_size x window_size\n            # if the resolution is not divisible by window_size, we need to interpolate the feature map\n            # can be done via padding, but doing so after training hurts the model performance.\n            # interpolation affects the performance as well, but not as much as padding\n            if isinstance(self.window_size, list) or isinstance(self.window_size, tuple):\n                current_max_window_size = max(self.window_size)\n            else:\n                current_max_window_size = self.window_size\n\n            max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio])\n            if H % max_window_size != 0 or W % max_window_size != 0:\n                new_h = int(np.ceil(H/max_window_size)*max_window_size)\n                new_w = 
int(np.ceil(W/max_window_size)*max_window_size)\n                x = F.interpolate(x, size=(new_h, new_w), mode='nearest')\n                if self.verbose:\n                    warnings.warn(f\"Choosen window size is not optimal for given resolution. Interpolation of features maps will be done and it can affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.\")\n\n\n        if self.transformer and self.do_single_windowing:\n            H, W = x.shape[2], x.shape[3]\n            x, pad_hw = window_partition(x, self.window_size)\n\n        #run main blocks\n        x = self.blocks(x)\n\n        if self.transformer and self.do_single_windowing:\n            x = window_reverse(x, self.window_size, H, W, pad_hw)\n\n        if self.transformer and interpolate:\n            #lets keep original resolution, might be not ideal, but for the upsampling tower we need to keep the expected resolution.\n            x = F.interpolate(x, size=(H, W), mode='nearest')\n\n        if self.downsample is None:\n            return x, x\n\n        return self.downsample(x), x  # changing to output pre downsampled features\n\n\nclass InterpolateLayer(nn.Module):\n    def __init__(self, size=None, scale_factor=None, mode='nearest'):\n        super(InterpolateLayer, self).__init__()\n        self.size = size\n        self.scale_factor = scale_factor\n        self.mode = mode\n\n    def forward(self, x):\n        return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)\n\n\nclass HiResNeck(nn.Module):\n    \"\"\"\n    The block is used to output dense features from all stages\n    Otherwise, by default, only the last stage features are returned with E-RADIO\n    \"\"\"\n    def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):\n\n        '''\n        Hi Resolution neck to support output of high res features that are useful for dense 
tasks.\n        depths - total number of layers in the base model\n        neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc.\n                            earlier layers result in higher resolution features at the cost of compute\n        full_features_head_dim - number of channels in the dense features head\n        '''\n        super().__init__()\n        # create feature projection layers for segmentation output\n        self.neck_features_proj = nn.ModuleList()\n        self.neck_start_stage = neck_start_stage\n        upsample_ratio = 1\n        for i in range(len(depths)):\n            level_n_features_output = int(dim * 2 ** i)\n\n            if self.neck_start_stage > i: continue\n\n            if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:\n                feature_projection = nn.Sequential()\n                if False:\n                    feature_projection.add_module(\"norm\",nn.BatchNorm2d(level_n_features_output)) #fast, but worse\n                    feature_projection.add_module(\"dconv\", nn.ConvTranspose2d(level_n_features_output,\n                                                                            full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))\n                else:\n                    # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio\n                    # print(\"upsample ratio\", upsample_ratio, level_n_features_output, level_n_features_output)\n                    feature_projection.add_module(\"upsample\", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest'))\n                    feature_projection.add_module(\"conv1\", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output))\n                    feature_projection.add_module(\"norm\",nn.BatchNorm2d(level_n_features_output))\n                    # B, 
in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio\n                    feature_projection.add_module(\"conv2\", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0))\n            else:\n                feature_projection = nn.Sequential()\n\n            self.neck_features_proj.append(feature_projection)\n\n            if i>0 and downsample_enabled[i]:\n                upsample_ratio *= 2\n\n    def forward(self, x, il_level=-1, full_features=None):\n        if self.neck_start_stage > il_level:\n            return full_features\n\n        if full_features is None:\n            full_features = self.neck_features_proj[il_level - self.neck_start_stage](x)\n        else:\n            #upsample torch tensor x to match full_features size, and add to full_features\n            feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)\n            if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:\n                feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))\n            full_features = full_features + feature_projection\n        return full_features\n\nclass ERADIO(nn.Module):\n    \"\"\"\n    Efficient RADIO\n    \"\"\"\n\n    def __init__(self,\n                 dim,\n                 in_dim,\n                 depths,\n                 window_size,\n                 mlp_ratio,\n                 num_heads,\n                 drop_path_rate=0.2,\n                 in_chans=3,\n                 num_classes=1000,\n                 qkv_bias=False,\n                 qk_scale=None,\n                 layer_scale=None,\n                 layer_scale_conv=None,\n                 layer_norm_last=False,\n                 sr_ratio = [1, 1, 1, 1],\n              
   max_depth = -1,\n                 conv_base=False,\n                 use_swiglu=False,\n                 multi_query=False,\n                 norm_layer=nn.LayerNorm,\n                 drop_uniform=False,\n                 yolo_arch=False,\n                 shuffle_down=False,\n                 downsample_shuffle=False,\n                 return_full_features=False,\n                 full_features_head_dim=128,\n                 neck_start_stage=1,\n                 use_neck=False,\n                 use_shift=False,\n                 cpb_mlp_hidden=512,\n                 conv_groups_ratio=0,\n                 verbose: bool = False,\n                 **kwargs):\n        \"\"\"\n        Args:\n            dim: feature size dimension.\n            depths: number of layers in each stage.\n            window_size: window size in each stage.\n            mlp_ratio: MLP ratio.\n            num_heads: number of heads in each stage.\n            drop_path_rate: drop path rate.\n            in_chans: number of input channels.\n            num_classes: number of classes.\n            qkv_bias: bool argument for query, key, value learnable bias.\n            qk_scale: bool argument to scaling query, key.\n            drop_rate: dropout rate.\n            attn_drop_rate: attention dropout rate.\n            norm_layer: normalization layer.\n            layer_scale: layer scaling coefficient.\n            return_full_features: output dense features as well as logits\n            full_features_head_dim: number of channels in the dense features head\n            neck_start_stage: a stage id to start full feature neck. 
Model has 4 stages, indix starts with 0\n                                for 224 resolution, the output of the stage before downsample:\n                                stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7\n            use_neck: even for summarization embedding use neck\n            use_shift: SWIN like window shifting but without masking attention\n            conv_groups_ratio: will be used for conv blocks where there is no multires attention,\n                                if 0 then normal conv,\n                                if 1 then channels are independent,\n                                if -1 then no conv at all\n\n        \"\"\"\n        super().__init__()\n\n        num_features = int(dim * 2 ** (len(depths) - 1))\n        self.num_classes = num_classes\n        self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)\n        # set return_full_features true if we want to return full features from all stages\n        self.return_full_features = return_full_features\n        self.use_neck = use_neck\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        if drop_uniform:\n            dpr = [drop_path_rate for x in range(sum(depths))]\n\n        if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)\n\n        self.levels = nn.ModuleList()\n        for i in range(len(depths)):\n            conv = True if (i == 0 or i == 1) else False\n\n            level = ERADIOLayer(dim=int(dim * 2 ** i),\n                                   depth=depths[i],\n                                   num_heads=num_heads[i],\n                                   window_size=window_size[i],\n                                   mlp_ratio=mlp_ratio,\n                                   qkv_bias=qkv_bias,\n                                   qk_scale=qk_scale,\n                                   conv=conv,\n                                   
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],\n                                   downsample=(i < len(depths) - 1),\n                                   layer_scale=layer_scale,\n                                   layer_scale_conv=layer_scale_conv,\n                                   sr_ratio=sr_ratio[i],\n                                   use_swiglu=use_swiglu,\n                                   multi_query=multi_query,\n                                   norm_layer=norm_layer,\n                                   yolo_arch=yolo_arch,\n                                   downsample_shuffle=downsample_shuffle,\n                                   conv_base=conv_base,\n                                   cpb_mlp_hidden=cpb_mlp_hidden,\n                                   use_shift=use_shift,\n                                   conv_groups_ratio=conv_groups_ratio,\n                                   verbose=verbose)\n\n            self.levels.append(level)\n\n        if self.return_full_features or self.use_neck:\n            #num_heads\n            downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))]\n            self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled)\n\n        self.switched_to_deploy = False\n\n        self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)\n        self.avgpool = nn.AdaptiveAvgPool2d(1)\n        self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, 
LayerNorm2d):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.BatchNorm2d):\n            nn.init.ones_(m.weight)\n            nn.init.zeros_(m.bias)\n\n    @torch.jit.ignore\n    def no_weight_decay_keywords(self):\n        return {'rpb'}\n\n    def forward_features(self, x):\n        _, _, H, W = x.shape\n        if H % 32 != 0 or W % 32 != 0:\n            raise ValueError(f\"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}\")\n        x = self.patch_embed(x)\n        full_features = None\n        for il, level in enumerate(self.levels):\n            x, pre_downsample_x = level(x)\n\n            if self.return_full_features or self.use_neck:\n                full_features = self.high_res_neck(pre_downsample_x, il, full_features)\n\n        # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)\n        x = self.norm(x) # new version for\n\n        if not self.return_full_features:\n            return x, None\n\n        return x, full_features\n\n    def forward(self, x):\n        x, full_features = self.forward_features(x)\n\n        x = self.avgpool(x)\n        x = torch.flatten(x, 1)\n\n        x = self.head(x)\n        if full_features is not None:\n            return x, full_features\n        return x\n\n    def switch_to_deploy(self):\n        '''\n        A method to perform model self-compression\n        merges BN into conv layers\n        converts MLP relative positional bias into precomputed buffers\n        '''\n        if not self.switched_to_deploy:\n            for level in [self.patch_embed, self.levels, self.head]:\n                for module in level.modules():\n                    if hasattr(module, 'switch_to_deploy'):\n                        module.switch_to_deploy()\n        self.switched_to_deploy = True\n\n\n    def change_window_size(self, new_window_size):\n        \"\"\"\n        E-RADIO employs windowed 
attention, which may be sensitive to the choice of this parameter,\n        especially in cases of uneven partitioning of the feature maps.\n        E-RADIO allows for the adjustment of the window size after training,\n        making it adaptable to different input image resolutions.\n        The recommended values for window size based on input resolution are as follows:\n\n        Input Resolution | Window Size\n        224 | 7\n        256 | 8\n        386 | 12\n        512 | 16\n        Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be\n        img_res/16/2\n        for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.\n        Manual way to change resolution -> model.change_window_size(resolution)\n        \"\"\"\n        window_size = new_window_size\n        print(f\"Setting window size to {window_size}\")\n        for module in self.modules():\n            if hasattr(module, \"window_size\"):\n                # check if tuple or a number\n                if isinstance(module.window_size, tuple):\n                    if module.window_size[0] != window_size:\n                        module.window_size = (window_size, window_size)\n                elif isinstance(module.window_size, list):\n                    if module.window_size[0] != window_size:\n                        module.window_size = [window_size, window_size]\n                else:\n                    module.window_size = window_size\n\n\n    def set_optimal_window_size(self, image_dim, max_window_size = 16):\n        \"\"\"\n        Using hand picked window size for various resolutions.\n\n        E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,\n        especially in cases of uneven partitioning of the feature maps.\n        E-RADIO allows for the 
adjustment of the window size after training,\n        making it adaptable to different input image resolutions.\n        The recommended values for window size based on input resolution are as follows:\n\n        Input Resolution | Window Size\n        224 | 7\n        256 | 8\n        386 | 12\n        512 | 16\n        Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be\n        img_res/16/2\n        for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.\n        Manual way to change resolution -> model.change_window_size(resolution)\n\n        \"\"\"\n        # import math\n\n        def divisorGenerator(n):\n            large_divisors = []\n            for i in range(1, int(math.sqrt(n) + 1)):\n                if n % i == 0:\n                    yield i\n                    if i*i != n:\n                        large_divisors.append(n / i)\n            for divisor in reversed(large_divisors):\n                yield divisor\n\n        if isinstance(image_dim, list) or isinstance(image_dim, tuple):\n            image_dim = min(image_dim)\n\n        # we do windowed attention in the 3rd stage for the first time, therefore //16,\n        # we do subsampled attention with downsample by 2 so need to get //32 actually\n        # ideally we should rewrite this to be dependent on the structure of the model like what if subsampled is removed etc\n        all_divisors = np.array(list(divisorGenerator(image_dim//32)))\n        new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))\n\n        # for image_dim in [128, 224, 256, 384, 512, 768, 1024]:\n        #     all_divisors = np.array(list(divisorGenerator(image_dim//32)))\n        #     new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))\n       
 #     print(f\"Setting window size to {new_window_size} for image resolution {image_dim}\")\n\n        self.change_window_size(new_window_size = new_window_size)\n\n\n@register_model\ndef eradio_large_fullres_ws16(pretrained=False, **kwargs):\n    model = ERADIO(\n        depths=[3, 3, 5, 5],\n        num_heads=[2, 4, 8, 16],\n        window_size=[None, None, [16, 16], 16],\n        dim=192,\n        in_dim=64,\n        mlp_ratio=4,\n        drop_path_rate=0.0,\n        sr_ratio=[1, 1, [2, 1], 1],\n        use_swiglu=False,\n        yolo_arch=True,\n        shuffle_down=False,\n        conv_base=True,\n        use_neck=True,\n        full_features_head_dim=1536,\n        neck_start_stage=2,\n        **kwargs,\n    )\n    if pretrained:\n        model.load_state_dict(torch.load(pretrained)[\"state_dict\"])\n    return model\n\n\n@register_model\ndef eradio_xxxtiny(pretrained=False, **kwargs):  # ,\n    model = ERADIO(\n        depths=[1, 3, 4, 5],\n        num_heads=[2, 4, 8, 16],\n        window_size=[None, None, [16, 16], 16],\n        dim=32,\n        in_dim=32,\n        mlp_ratio=4,\n        drop_path_rate=0.0,\n        sr_ratio=[1, 1, [2, 1], 1],\n        use_swiglu=False,\n        yolo_arch=True,\n        shuffle_down=False,\n        conv_base=True,\n        use_neck=True,\n        full_features_head_dim=256,\n        neck_start_stage=2,\n        **kwargs,\n    )\n    if pretrained:\n        model.load_state_dict(torch.load(pretrained))\n    return model\n\n@register_model\ndef eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):\n    model = ERADIO(depths=[1, 3, 4, 5],\n        num_heads=[2, 4, 8, 16],\n        window_size=[None, None, [12, 12], 12],\n        dim=32,\n        in_dim=32,\n        mlp_ratio=4,\n        drop_path_rate=0.0,\n        sr_ratio=[1, 1, [2, 1], 1],\n        use_swiglu=False,\n        downsample_shuffle=False,\n        yolo_arch=True,\n        shuffle_down=False,\n        cpb_mlp_hidden=64,\n        use_neck=True,\n        
full_features_head_dim=256,\n        neck_start_stage=2,\n        conv_groups_ratio = 1,\n        **kwargs)\n    if pretrained:\n        model.load_state_dict(torch.load(pretrained)[\"state_dict\"])\n    return model\n\n\n@register_model\ndef eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):\n    model = ERADIO(depths=[1, 3, 4, 5],\n        num_heads=[2, 4, 8, 16],\n        window_size=[None, None, [16, 16], 16],\n        dim=32,\n        in_dim=32,\n        mlp_ratio=4,\n        drop_path_rate=0.0,\n        sr_ratio=[1, 1, [2, 1], 1],\n        use_swiglu=False,\n        downsample_shuffle=False,\n        yolo_arch=True,\n        shuffle_down=False,\n        cpb_mlp_hidden=64,\n        use_neck=True,\n        full_features_head_dim=256,\n        neck_start_stage=1,\n        conv_groups_ratio = 1,\n        **kwargs)\n    if pretrained:\n        model.load_state_dict(torch.load(pretrained)[\"state_dict\"])\n    return model\n\n@register_model\ndef eradio(pretrained=False, **kwargs):\n    return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/extra_models.py",
    "content": "from distutils.version import LooseVersion\nfrom types import MethodType\nfrom typing import List, Optional, Tuple, Union\nimport warnings\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom timm.models.registry import register_model\nfrom timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n\nfrom .forward_intermediates import forward_intermediates\nfrom .input_conditioner import InputConditioner\n\n_has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')\n\n\nclass PaliGemmaWrapper(nn.Module):\n    def __init__(self, vis_model: nn.Module, embed_dim: int):\n        super().__init__()\n\n        self.vis_model = vis_model\n        self.embed_dim = embed_dim\n\n    @property\n    def patch_size(self):\n        return self.vis_model.embeddings.patch_size\n\n    @property\n    def blocks(self):\n        return self.vis_model.encoder.layers\n\n    @property\n    def embed_dim(self):\n        return self.vis_model.embeddings.embed_dim\n\n    def forward(self, x: torch.Tensor):\n        outputs = self.vis_model(\n            x,\n            return_dict=False,\n            interpolate_pos_encoding=True,\n        )\n\n        features = outputs[0].to(torch.float32)\n\n        summary = features.mean(dim=1)\n\n        return summary, features\n\n    def forward_features(self, x: torch.Tensor):\n        return self(x)\n\n\ndef _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):\n    from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version\n\n    if LooseVersion(tx_version) > LooseVersion('4.44.2'):\n        warnings.warn(f'Your transformers version \"{tx_version}\" is higher than 4.44.2, and for whatever reason, PaliGemma might be broken.')\n\n    extra_args = dict()\n\n    if dtype is not None:\n        extra_args['torch_dtype'] = dtype\n        rev = str(dtype).split('.')[-1]\n        extra_args['revision'] = rev\n\n    model = 
PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)\n\n    vis_model = model.vision_tower.vision_model\n\n    vis_model = PaliGemmaWrapper(vis_model, embed_dim)\n\n    return vis_model\n\n@register_model\ndef paligemma_896_student(**kwargs):\n    model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)\n\n    return model\n\n\ndef dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:\n    B, N, C = x.shape\n    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n    q, k, v = qkv[0], qkv[1], qkv[2]\n    x = F.scaled_dot_product_attention(\n        q, k, v,\n        is_causal=False,\n        dropout_p=self.attn_drop.p if self.training else 0.,\n        scale=self.scale,\n    )\n    x = x.transpose(1, 2).reshape(B, N, C)\n    x = self.proj(x)\n    x = self.proj_drop(x)\n    return x\n\ndef _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):\n    if cache_dir:\n        torch.hub.set_dir(cache_dir)\n    model: nn.Module = torch.hub.load(\n        'facebookresearch/dinov2',\n        dino_v2_model,\n        pretrained=pretrained,\n        # **kwargs,\n    )\n\n    if _has_torch_sdpa:\n        for n, m in model.named_modules():\n            if n.endswith('.attn'):\n                m.forward = MethodType(dv2_sdpa, m)\n\n    return model\n\nclass DinoWrapper(nn.Module):\n    def __init__(self, dino_model: nn.Module):\n        super().__init__()\n\n        self.inner = dino_model\n        dino_model.blocks = nn.Sequential(*dino_model.blocks)\n\n    @property\n    def embed_dim(self):\n        return self.inner.embed_dim\n\n    @property\n    def patch_size(self):\n        return self.inner.patch_size\n\n    @property\n    def num_cls_tokens(self):\n        return getattr(self.inner, 'num_tokens', 1)\n\n    @property\n    def num_registers(self):\n        return getattr(self.inner, 'num_register_tokens', 0)\n\n    @property\n    def 
num_summary_tokens(self):\n        return self.num_cls_tokens + self.num_registers\n\n    @property\n    def blocks(self):\n        return self.inner.blocks\n\n    def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:\n        parts = self.inner.forward_features(*args, **kwargs)\n\n        cls_token = parts['x_norm_clstoken']\n        features = parts['x_norm_patchtokens']\n\n        return cls_token, features\n\n    def forward_features(self, x: torch.Tensor):\n        x = self.inner.prepare_tokens_with_masks(x)\n        x = self.inner.blocks(x)\n        x_norm = self.inner.norm(x)\n\n        return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]\n\n    def patchify(self, x: torch.Tensor) -> torch.Tensor:\n        return self.inner.prepare_tokens_with_masks(x)\n\n    def forward_intermediates(self,\n        x: torch.Tensor,\n        norm: bool = False,\n        **kwargs,\n    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:\n        return forward_intermediates(\n            self,\n            patch_extractor=self.inner.prepare_tokens_with_masks,\n            num_summary_tokens=self.num_summary_tokens,\n            num_cls_tokens=self.num_cls_tokens,\n            norm=self.inner.norm if norm else lambda y: y,\n            x=x,\n            **kwargs,\n        )\n\n\ndef _dino_student(arch: str, **kwargs):\n    from . import dinov2_arch\n\n    factory = getattr(dinov2_arch, arch)\n    model = factory()\n\n    model = DinoWrapper(model)\n\n    conditioner = InputConditioner(\n        input_scale=1.0,\n        norm_mean=IMAGENET_DEFAULT_MEAN,\n        norm_std=IMAGENET_DEFAULT_STD,\n    )\n\n    model.input_conditioner = conditioner\n\n    return model\n\n\n@register_model\ndef dino_v2_l_student(**kwargs):\n    return _dino_student('dinov2_vitl14_reg', **kwargs)\n\n@register_model\ndef dino_v2_g_student(**kwargs):\n    return _dino_student('dinov2_vitg14_reg', **kwargs)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/extra_timm_models.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport math\nimport warnings\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom timm.models import register_model\nfrom timm.models.vision_transformer import (\n    VisionTransformer,\n    _create_vision_transformer as _timm_create_vision_transformer,\n    Mlp,\n    Block,\n    LayerScale as TIMMLayerScale,\n)\n\n# Import these to also register them\nfrom . import dinov2_arch\n\n\n@register_model\ndef vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Tiny (Vit-Ti/16)\n    \"\"\"\n    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)\n    model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))\n    return model\n\n\n@register_model\ndef vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Small (ViT-S/16)\n    \"\"\"\n    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)\n    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))\n    return model\n\n\n@register_model\ndef vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)\n    model = 
_create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))\n    return model\n\n\n@register_model\ndef vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_args = dict(\n        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,\n        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14\n    )\n    model = _create_vision_transformer(\n        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))\n    return model\n\n\n@register_model\ndef vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    name = 'vit_large_patch14_reg4_dinov2'\n    model_args = dict(\n        patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,\n        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14\n    )\n    model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))\n\n    return model\n\n@register_model\ndef vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).\n    \"\"\"\n    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)\n    if pretrained:\n        # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose\n        model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))\n    else:\n        model = 
_create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs))\n    return model\n\n\n@register_model\ndef vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).\n    \"\"\"\n    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)\n\n    for m in model.modules():\n        if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):\n            m.norm = nn.LayerNorm(m.fc1.out_features)\n\n    return model\n\n\n@register_model\ndef vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).\n    \"\"\"\n    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)\n    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))\n    if scaled_ln:\n        _apply_scaled_ln(model)\n    return model\n\n\n@register_model\ndef vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:\n    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)\n    model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))\n    return model\n\n\ndef _create_vision_transformer(*args, **kwargs):\n    model = _timm_create_vision_transformer(*args, **kwargs)\n    _patch_layer_scale(model)\n    return model\n\n\ndef _patch_layer_scale(model: VisionTransformer):\n    def replace_ls(old_ls: TIMMLayerScale):\n        new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)\n        new_ls.load_state_dict(old_ls.state_dict())\n        return new_ls\n\n    # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name\n    # other than gamma, so that HFHub doesn't mess with it!\n    for mod in 
model.modules():\n        if isinstance(mod, Block):\n            if isinstance(mod.ls1, TIMMLayerScale):\n                mod.ls1 = replace_ls(mod.ls1)\n            if isinstance(mod.ls2, TIMMLayerScale):\n                mod.ls2 = replace_ls(mod.ls2)\n    pass\n\n\nclass ScaledLayerNorm(nn.LayerNorm):\n    '''\n    https://arxiv.org/pdf/2502.05795v1\n    '''\n    def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):\n        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)\n        self.load_state_dict(ln_base.state_dict())\n        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)\n\n    def forward(self, x):\n        y = super().forward(x)\n        y = y * self.ln_scale\n        return y\n\n\nclass DyT(nn.Module):\n    def __init__(self, C: int, init_alpha: float):\n        super().__init__()\n        self.alpha = nn.Parameter(torch.full((1,), init_alpha))\n        self.gamma = nn.Parameter(torch.ones(C))\n        self.beta = nn.Parameter(torch.zeros(C))\n\n    def forward(self, x: torch.Tensor):\n        x = F.tanh(self.alpha * x)\n        return self.gamma * x + self.beta\n\n@register_model\ndef vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:\n    \"\"\" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).\n    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.\n    \"\"\"\n    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)\n    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))\n\n    def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):\n        return DyT(ln.normalized_shape[0], init_alpha=0.9)\n    _replace_ln(model, _replace_ln_with_dyt)\n\n    return model\n\n\ndef _apply_scaled_ln(model: VisionTransformer):\n    
warnings.warn('Post-LayerNorm scaling activated!')\n\n    _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))\n\ndef _replace_ln(model: VisionTransformer, fn):\n    def _inner_replace_ln(block: Block, depth: int, key: str):\n        prev = getattr(block, key)\n        if isinstance(prev, nn.LayerNorm):\n            setattr(block, key, fn(prev, depth=depth))\n\n    for i, block in enumerate(model.blocks):\n        _inner_replace_ln(block, i + 1, 'norm1')\n        _inner_replace_ln(block, i + 1, 'norm2')\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/feature_normalizer.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom collections import namedtuple\nfrom typing import NamedTuple, Optional, Tuple\nimport torch\nfrom torch import nn\n\n\ndef _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):\n    if x.ndim <= 3:\n        x = x - mean\n        x = x @ tx.T\n    elif x.ndim == 4:\n        x = x - mean.reshape(1, -1, 1, 1)\n        kernel = tx.reshape(*tx.shape, 1, 1)\n        x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)\n    else:\n        raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')\n    return x\n\n\nclass FeatureNormalizer(nn.Module):\n    def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):\n        super().__init__()\n\n        self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))\n        self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = _run_kernel(x, self.mean, self.tx)\n        return x\n\n\nclass InterFeatState(NamedTuple):\n    y: torch.Tensor\n    alpha: torch.Tensor\n\n\nclass IntermediateFeatureNormalizerBase(nn.Module):\n    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:\n        raise NotImplementedError()\n\n\nclass IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):\n    def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):\n        super().__init__()\n        
self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))\n\n        rot = torch.eye(embed_dim, dtype=dtype)\n        if rot_per_layer:\n            rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)\n\n        self.register_buffer('rotation', rot.contiguous())\n        self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))\n\n    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:\n        if rot_index is None:\n            rot_index = index\n\n        if skip:\n            assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\\'t 3-dimensional.'\n            prefix, x = x[:, :skip], x[:, skip:]\n\n        rotation = self._get_rotation(rot_index)\n        y = _run_kernel(x, self.means[index], rotation)\n\n        alpha = self.alphas[index]\n        if skip:\n            alpha = torch.cat([\n                torch.ones(skip, dtype=alpha.dtype, device=alpha.device),\n                alpha[None].expand(y.shape[1]),\n            ]).reshape(1, -1, 1)\n            y = torch.cat([prefix, y], dim=1)\n        else:\n            if x.ndim == 3:\n                alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)\n            elif x.ndim == 4:\n                alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])\n            else:\n                raise ValueError(f'Unsupported input dimension: {x.ndim}')\n\n        return InterFeatState(y, alpha)\n\n    def _get_rotation(self, rot_index: int) -> torch.Tensor:\n        if self.rotation.ndim == 2:\n            return self.rotation\n        return self.rotation[rot_index]\n\n\nclass NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):\n    instances = dict()\n\n    def __init__(self, dtype: torch.dtype, device: torch.device):\n        super().__init__()\n        self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device))\n\n    @staticmethod\n    
def get_instance(dtype: torch.dtype, device: torch.device):\n        instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None)\n        if instance is None:\n            instance = NullIntermediateFeatureNormalizer(dtype, device)\n            NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance\n        return instance\n\n    def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:\n        return InterFeatState(x, self.alpha)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/forward_intermediates.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable\nfrom types import MethodType\n\nimport torch\nfrom torch import nn\n\nfrom .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer\n\n\ndef _take_indices(\n        num_blocks: int,\n        n: Optional[Union[int, List[int], Tuple[int]]],\n) -> Tuple[Set[int], int]:\n    if isinstance(n, int):\n        assert n >= 0\n        take_indices = {x for x in range(num_blocks - n, num_blocks)}\n    else:\n        take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}\n    return take_indices, max(take_indices)\n\n\ndef forward_intermediates(\n        model: nn.Module,\n        patch_extractor: Callable[[torch.Tensor], torch.Tensor],\n        norm: nn.Module,\n        num_summary_tokens: int,\n        num_cls_tokens: int,\n        x: torch.Tensor,\n        indices: Optional[Union[int, List[int], Tuple[int]]] = None,\n        return_prefix_tokens: bool = False,\n        stop_early: bool = False,\n        output_fmt: str = 'NCHW',\n        intermediates_only: bool = False,\n        aggregation: Optional[str] = \"sparse\",\n        inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,\n        norm_alpha_scheme = \"post-alpha\",\n        block_kwargs: Dict = None,\n) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:\n    \"\"\" Forward features that returns intermediates.\n\n    The Dense layer aggregation method is inspired from the paper: \"Dense 
Connector for MLLMs\"\n    by Yao, Huanjin et al. (2024). arXiv preprint arXiv:2405.13800}\n\n    Args:\n        x: Input image tensor\n        indices: Take last n blocks if int, select matching indices if sequence\n        return_prefix_tokens: Return both prefix and spatial intermediate tokens\n        norm: Apply norm layer to all intermediates\n        stop_early: Stop iterating over blocks when last desired intermediate hit\n        output_fmt: Shape of intermediate feature outputs\n        intermediates_only: Only return intermediate features\n        aggregation: intermediate layer aggregation method (sparse or dense)\n        norm_alpha_scheme: apply alpha before (\"pre-alpha\") or after accumulation (\"post-alpha\")\n    Returns:\n    \"\"\"\n    assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'\n    assert aggregation in ('sparse', 'dense'), 'Aggregation must be one of sparse or dense.'\n    reshape = output_fmt == 'NCHW'\n    intermediates = []\n\n    block_kwargs = block_kwargs or dict()\n\n    blocks = model.blocks\n\n    take_indices, max_index = _take_indices(len(blocks), indices)\n    take_indices = sorted(take_indices)\n    # forward pass\n    B, _, height, width = x.shape\n\n    x = patch_extractor(x)\n\n    if stop_early:\n        blocks = blocks[:max_index + 1]\n\n    if inter_feature_normalizer is None or norm_alpha_scheme == 'none':\n        inter_feature_normalizer = NullIntermediateFeatureNormalizer.get_instance(x.dtype, x.device)\n\n    assert norm_alpha_scheme in ('none', 'pre-alpha', 'post-alpha'), f'Unsupported alpha scheme: {norm_alpha_scheme}'\n    post_alpha_scheme = norm_alpha_scheme == 'post-alpha'\n\n    accumulator = 0\n    alpha_sum = 0\n    num_accumulated = 0\n\n    take_off = 0\n\n    for i, blk in enumerate(blocks):\n        x = blk(x, **block_kwargs)\n        if aggregation == \"dense\":\n            # Arbitrarily use the rotation matrix from the final layer in the dense group\n            y, 
alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)\n            if post_alpha_scheme:\n                accumulator = accumulator + y\n                alpha_sum = alpha_sum + alpha\n            else:\n                accumulator = accumulator + (alpha * y)\n                alpha_sum += 1\n            num_accumulated += 1\n        if i == take_indices[take_off]:\n            if aggregation == \"dense\":\n                alpha = alpha_sum / num_accumulated\n                x_ = alpha * accumulator / num_accumulated\n                num_accumulated = 0\n                accumulator = 0\n                alpha_sum = 0\n            else:\n                 y, alpha = inter_feature_normalizer(x, i, skip=num_summary_tokens)\n                 x_ = alpha * y\n            # normalize intermediates with final norm layer if enabled\n            intermediates.append(norm(x_))\n            take_off = min(take_off + 1, len(take_indices) - 1)\n\n    # process intermediates\n\n    # split prefix (e.g. class, distill) and spatial feature tokens\n    prefix_tokens = [y[:, :num_cls_tokens] for y in intermediates]\n    intermediates = [y[:, num_summary_tokens:] for y in intermediates]\n\n    if reshape:\n        # reshape to BCHW output format\n        H = height // model.patch_size\n        W = width // model.patch_size\n        intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]\n    if not torch.jit.is_scripting() and return_prefix_tokens:\n        # return_prefix not support in torchscript due to poor type handling\n        intermediates = list(zip(prefix_tokens, intermediates))\n    if intermediates_only:\n        return intermediates\n    x = norm(x)\n    return x, intermediates\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/hf_model.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom collections import namedtuple\nfrom typing import Callable, Dict, Optional, List, Union\n\nfrom timm.models import VisionTransformer\nimport torch\nfrom torch import nn\nfrom transformers import PretrainedConfig, PreTrainedModel\n\n\nfrom .common import RESOURCE_MAP, DEFAULT_VERSION\n\n# Import all required modules.\nfrom .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput\nfrom .adaptor_generic import GenericAdaptor, AdaptorBase\nfrom .adaptor_mlp import create_mlp_from_config\nfrom .adaptor_registry import adaptor_registry\nfrom .cls_token import ClsToken\nfrom .dinov2_arch import dinov2_vitg14_reg\nfrom .enable_cpe_support import enable_cpe\nfrom .enable_spectral_reparam import configure_spectral_reparam_from_args\nfrom .eradio_model import eradio\nfrom .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer\nfrom .forward_intermediates import forward_intermediates\nfrom .radio_model import create_model_from_args\nfrom .radio_model import RADIOModel as RADIOModelBase, Resolution\nfrom .input_conditioner import get_default_conditioner, InputConditioner\nfrom .open_clip_adaptor import OpenCLIP_RADIO\nfrom .vit_patch_generator import ViTPatchGenerator\nfrom .vitdet import apply_vitdet_arch, VitDetArgs\n\n# Register extra models\nfrom .extra_timm_models import *\nfrom .extra_models import 
*\n\n\nclass RADIOConfig(PretrainedConfig):\n    \"\"\"Pretrained Hugging Face configuration for RADIO models.\"\"\"\n\n    def __init__(\n        self,\n        args: Optional[dict] = None,\n        version: Optional[str] = DEFAULT_VERSION,\n        patch_size: Optional[int] = None,\n        max_resolution: Optional[int] = None,\n        preferred_resolution: Optional[Resolution] = None,\n        adaptor_names: Union[str, List[str]] = None,\n        adaptor_configs: Dict[str, Dict[str, int]] = None,\n        vitdet_window_size: Optional[int] = None,\n        feature_normalizer_config: Optional[dict] = None,\n        inter_feature_normalizer_config: Optional[dict] = None,\n        **kwargs,\n    ):\n        self.args = args\n        for field in [\"dtype\", \"amp_dtype\"]:\n            if self.args is not None and field in self.args:\n                # Convert to a string in order to make it serializable.\n                # For example for torch.float32 we will store \"float32\",\n                # for \"bfloat16\" we will store \"bfloat16\".\n                self.args[field] = str(args[field]).split(\".\")[-1]\n        self.version = version\n        resource = RESOURCE_MAP[version]\n        self.patch_size = patch_size or resource.patch_size\n        self.max_resolution = max_resolution or resource.max_resolution\n        self.preferred_resolution = (\n            preferred_resolution or resource.preferred_resolution\n        )\n        self.adaptor_names = adaptor_names\n        self.adaptor_configs = adaptor_configs\n        self.vitdet_window_size = vitdet_window_size\n        self.feature_normalizer_config = feature_normalizer_config\n        self.inter_feature_normalizer_config = inter_feature_normalizer_config\n        super().__init__(**kwargs)\n\n\n\nclass RADIOModel(PreTrainedModel):\n    \"\"\"Pretrained Hugging Face model for RADIO.\n\n    This class inherits from PreTrainedModel, which provides\n    HuggingFace's functionality for loading and saving 
models.\n    \"\"\"\n\n    config_class = RADIOConfig\n\n    def __init__(self, config: RADIOConfig):\n        super().__init__(config)\n\n        RADIOArgs = namedtuple(\"RADIOArgs\", config.args.keys())\n        args = RADIOArgs(**config.args)\n        self.config = config\n\n        model = create_model_from_args(args)\n        input_conditioner: InputConditioner = get_default_conditioner()\n\n        dtype = getattr(args, \"dtype\", torch.float32)\n        if isinstance(dtype, str):\n            # Convert the dtype's string representation back to a dtype.\n            dtype = getattr(torch, dtype)\n        model.to(dtype=dtype)\n        input_conditioner.dtype = dtype\n\n        summary_idxs = torch.tensor(\n            [i for i, t in enumerate(args.teachers) if t.get(\"use_summary\", True)],\n            dtype=torch.int64,\n        )\n\n        adaptor_configs = config.adaptor_configs\n        adaptor_names = config.adaptor_names or []\n\n        adaptors = dict()\n        for adaptor_name in adaptor_names:\n            mlp_config = adaptor_configs[adaptor_name]\n            adaptor = GenericAdaptor(args, None, None, mlp_config)\n            adaptor.head_idx = mlp_config[\"head_idx\"]\n            adaptors[adaptor_name] = adaptor\n\n        feature_normalizer = None\n        if config.feature_normalizer_config is not None:\n            # Actual normalization values will be restored when loading checkpoint weights.\n            feature_normalizer = FeatureNormalizer(config.feature_normalizer_config[\"embed_dim\"])\n\n        inter_feature_normalizer = None\n        if config.inter_feature_normalizer_config is not None:\n            inter_feature_normalizer = IntermediateFeatureNormalizer(\n                config.inter_feature_normalizer_config[\"num_intermediates\"],\n                config.inter_feature_normalizer_config[\"embed_dim\"],\n                rot_per_layer=config.inter_feature_normalizer_config[\"rot_per_layer\"],\n                dtype=dtype)\n\n   
     self.radio_model = RADIOModelBase(\n            model,\n            input_conditioner,\n            summary_idxs=summary_idxs,\n            patch_size=config.patch_size,\n            max_resolution=config.max_resolution,\n            window_size=config.vitdet_window_size,\n            preferred_resolution=config.preferred_resolution,\n            adaptors=adaptors,\n            feature_normalizer=feature_normalizer,\n            inter_feature_normalizer=inter_feature_normalizer,\n        )\n\n    @property\n    def adaptors(self) -> nn.ModuleDict:\n        return self.radio_model.adaptors\n\n    @property\n    def model(self) -> VisionTransformer:\n        return self.radio_model.model\n\n    @property\n    def input_conditioner(self) -> InputConditioner:\n        return self.radio_model.input_conditioner\n\n    @property\n    def num_summary_tokens(self) -> int:\n        return self.radio_model.num_summary_tokens\n\n    @property\n    def patch_size(self) -> int:\n        return self.radio_model.patch_size\n\n    @property\n    def max_resolution(self) -> int:\n        return self.radio_model.max_resolution\n\n    @property\n    def preferred_resolution(self) -> Resolution:\n        return self.radio_model.preferred_resolution\n\n    @property\n    def window_size(self) -> int:\n        return self.radio_model.window_size\n\n    @property\n    def min_resolution_step(self) -> int:\n        return self.radio_model.min_resolution_step\n\n    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:\n        return self.radio_model.make_preprocessor_external()\n\n    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:\n        return self.radio_model.get_nearest_supported_resolution(height, width)\n\n    def switch_to_deploy(self):\n        return self.radio_model.switch_to_deploy()\n\n    def forward(self, x: torch.Tensor):\n        return self.radio_model.forward(x)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/input_conditioner.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom typing import Union, Tuple\n\nimport torch\nfrom torch import nn\n\n\nnorm_t = Union[Tuple[float, float, float], torch.Tensor]\n\nclass InputConditioner(nn.Module):\n    def __init__(self,\n                 input_scale: float,\n                 norm_mean: norm_t,\n                 norm_std: norm_t,\n                 dtype: torch.dtype = None,\n    ):\n        super().__init__()\n\n        self.dtype = dtype\n\n        self.register_buffer(\"norm_mean\", _to_tensor(norm_mean) / input_scale)\n        self.register_buffer(\"norm_std\", _to_tensor(norm_std) / input_scale)\n\n    def forward(self, x: torch.Tensor):\n        y = (x - self.norm_mean) / self.norm_std\n        if self.dtype is not None:\n            y = y.to(self.dtype)\n        return y\n\n\ndef get_default_conditioner():\n    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD\n\n    return InputConditioner(\n        input_scale=1.0,\n        norm_mean=OPENAI_CLIP_MEAN,\n        norm_std=OPENAI_CLIP_STD,\n    )\n\n\ndef _to_tensor(v: norm_t):\n    return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/open_clip_adaptor.py",
    "content": "# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom argparse import Namespace\n\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom .adaptor_registry import adaptor_registry, dict_t, state_t\n\nfrom .adaptor_generic import GenericAdaptor\n\n\nclass OpenCLIP_RADIO(GenericAdaptor):\n    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):\n        super().__init__(main_config, adaptor_config, state)\n\n        import open_clip\n\n        self.oc_model = open_clip.create_model_from_pretrained(\n            model_name=adaptor_config['model'],\n            pretrained=adaptor_config['pretrained'],\n            return_transform=False,\n        )\n        # Unload these parameters\n        self.oc_model.visual = None\n\n        self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])\n\n    def encode_text(self, text, normalize: bool = False):\n        return self.oc_model.encode_text(text, normalize=normalize)\n\n\n@adaptor_registry.register_adaptor(\"open_clip\")\ndef create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):\n    return OpenCLIP_RADIO(main_config, adaptor_config, state)\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/radio_model.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\nfrom typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\n\nfrom timm.models import create_model, VisionTransformer\nfrom types import MethodType\n\nfrom .enable_cpe_support import enable_cpe\nfrom .input_conditioner import InputConditioner\nfrom .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput\nfrom . import eradio_model\nfrom .enable_spectral_reparam import configure_spectral_reparam_from_args\nfrom .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer\nfrom . 
import dual_hybrid_vit\n\n\nclass Resolution(NamedTuple):\n    height: int\n    width: int\n\n\nclass RADIOModel(nn.Module):\n    def __init__(\n        self,\n        model: nn.Module,\n        input_conditioner: InputConditioner,\n        patch_size: int,\n        max_resolution: int,\n        preferred_resolution: Resolution,\n        summary_idxs: Optional[torch.Tensor] = None,\n        window_size: int = None,\n        adaptors: Dict[str, AdaptorBase] = None,\n        feature_normalizer: Optional[FeatureNormalizer] = None,\n        inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,\n    ):\n        super().__init__()\n\n        self.model = model\n        self.input_conditioner = input_conditioner\n        if summary_idxs is not None:\n            self.register_buffer('summary_idxs', summary_idxs)\n        else:\n            self.summary_idxs = None\n\n        self._preferred_resolution = preferred_resolution\n        self._patch_size = patch_size\n        self._max_resolution = max_resolution\n        self._window_size = window_size\n\n        adaptors = adaptors or dict()\n        self.adaptors = nn.ModuleDict(adaptors)\n\n        if feature_normalizer is None:\n            feature_normalizer = nn.Identity()\n        self.feature_normalizer = feature_normalizer\n        self.inter_feature_normalizer = inter_feature_normalizer\n\n    @property\n    def num_summary_tokens(self) -> int:\n        if hasattr(self.model, 'num_summary_tokens'):\n            return self.model.num_summary_tokens\n\n        patch_gen = getattr(self.model, \"patch_generator\", None)\n        if patch_gen is not None:\n            return patch_gen.num_skip\n        elif getattr(self.model, 'global_pool', None) == 'avg':\n            return 0\n        return 1\n\n    @property\n    def num_cls_tokens(self) -> int:\n        if hasattr(self.model, 'num_cls_tokens'):\n            return self.model.num_cls_tokens\n\n        patch_gen = getattr(self.model, 
'patch_generator', None)\n        if patch_gen is not None:\n            return patch_gen.num_cls_tokens\n        elif getattr(self.model, 'global_pool', None) == 'avg':\n            return 0\n        return 1\n\n    @property\n    def patch_size(self) -> int:\n        if self._patch_size is not None:\n            return self._patch_size\n        if hasattr(self.model, \"patch_size\"):\n            return self.model.patch_size\n        patch_gen = getattr(self.model, \"patch_generator\", None)\n        if patch_gen is not None:\n            return patch_gen.patch_size\n        return None\n\n    @property\n    def max_resolution(self) -> int:\n        return self._max_resolution\n\n    @property\n    def preferred_resolution(self) -> Resolution:\n        return self._preferred_resolution\n\n    @property\n    def window_size(self) -> int:\n        return self._window_size\n\n    @property\n    def min_resolution_step(self) -> int:\n        res = self.patch_size\n        if self.window_size is not None:\n            res *= self.window_size\n        return res\n\n    @property\n    def blocks(self) -> Iterable[nn.Module]:\n        blocks = getattr(self.model, 'blocks', None)\n        if blocks is not None:\n            return blocks\n        return None\n\n    @property\n    def embed_dim(self) -> int:\n        return self.model.embed_dim\n\n    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:\n        ret = self.input_conditioner\n        self.input_conditioner = nn.Identity()\n        return ret\n\n    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:\n        height = int(round(height / self.min_resolution_step) * self.min_resolution_step)\n        width = int(round(width / self.min_resolution_step) * self.min_resolution_step)\n\n        height = max(height, self.min_resolution_step)\n        width = max(width, self.min_resolution_step)\n\n        return Resolution(height=height, width=width)\n\n    
def switch_to_deploy(self):\n        fn = getattr(self.model, 'switch_to_deploy', None)\n        if fn is not None:\n            fn()\n\n    def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        '''\n        Forward process for model.\n        Args:\n            x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,\n                             otherwise `x` is expected to be mean centered with unit standard deviation.\n            feature_format: ['NLC', 'NCHW'] - The output format for the features.\n        '''\n        res_step = self.min_resolution_step\n        if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):\n            raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '\n                             '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '\n                             f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')\n\n        x = self.input_conditioner(x)\n        y = self.model.forward_features(x)\n        ret = self._extract_final(x, y, feature_fmt=feature_fmt)\n        return ret\n\n    def forward_pack(self, x: List[torch.Tensor], feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        '''\n        Forward process for model.\n        Args:\n            x: Input tensor. 
Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,\n                             otherwise `x` is expected to be mean centered with unit standard deviation.\n            feature_format: ['NLC', 'NCHW'] - The output format for the features.\n        '''\n        res_step = self.min_resolution_step\n        for _x in x:\n            if res_step is not None and (_x.shape[-2] % res_step != 0 or _x.shape[-1] % res_step != 0):\n                raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '\n                                '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '\n                                f'Input: {_x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*_x.shape[-2:])}')\n\n        x = [self.input_conditioner(_x) for _x in x]\n        y, cu_seqlens = self.model.forward_features(x)\n        all_summary, spatial_features = [], []\n        num_cls_tokens = self.model.patch_generator.num_cls_tokens\n        num_skip = self.model.patch_generator.num_skip\n        for i in range(len(cu_seqlens)-1):\n            summary = y[cu_seqlens[i]: cu_seqlens[i+1]][: num_cls_tokens]\n            all_feat = y[cu_seqlens[i]: cu_seqlens[i+1]][num_skip :]\n            all_summary.append(summary)\n            spatial_features.append(all_feat)\n        all_summary = torch.cat(all_summary)\n        spatial_features = torch.cat(spatial_features)\n        return all_summary, spatial_features\n\n    def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):\n        if isinstance(self.model, VisionTransformer):\n            patch_gen = getattr(self.model, \"patch_generator\", None)\n            if patch_gen is not None:\n                all_summary = y[:, : patch_gen.num_cls_tokens]\n                if self.summary_idxs is not None:\n                    bb_summary = all_summary[:, self.summary_idxs]\n 
               else:\n                    bb_summary = all_summary\n                all_feat = y[:, patch_gen.num_skip :]\n            elif self.model.global_pool == \"avg\":\n                all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)\n                bb_summary = all_summary\n                all_feat = y\n            else:\n                all_summary = y[:, 0]\n                bb_summary = all_summary\n                all_feat = y[:, 1:]\n        elif isinstance(self.model, eradio_model.ERADIO):\n            _, f = y\n            all_feat = f.flatten(2).transpose(1, 2)\n            all_summary = all_feat.mean(dim=1)\n            bb_summary = all_summary\n        elif isinstance(y, (list, tuple)):\n            all_summary, all_feat = y\n            bb_summary = all_summary\n        else:\n            all_summary = y[:, :self.num_cls_tokens]\n            if self.summary_idxs is not None and all_summary.shape[1] > 1:\n                if all_summary.shape[1] == 1:\n                    # Create dummy duplicates\n                    all_summary = all_summary.expand(-1, 128, -1)\n                bb_summary = all_summary[:, self.summary_idxs]\n            else:\n                bb_summary = all_summary\n            all_feat = y[:, self.num_summary_tokens:]\n\n        all_feat = self.feature_normalizer(all_feat)\n\n        if feature_fmt == 'NCHW':\n            fmt_feat = (all_feat.reshape(all_feat.shape[0], x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2])\n                                .permute(0, 3, 1, 2)\n            )\n        elif feature_fmt == 'NLC':\n            fmt_feat = all_feat\n        else:\n            raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. 
Must be one of [\"NLC\", \"NCHW\"]')\n\n        ret = RadioOutput(bb_summary.flatten(1), fmt_feat)\n\n        if self.adaptors:\n            ret = dict(backbone=ret)\n            for name, adaptor in self.adaptors.items():\n                if all_summary.ndim == 3:\n                    if all_summary.shape[1] == 1:\n                        summary = all_summary[:, 0]\n                    else:\n                        summary = all_summary[:, adaptor.head_idx]\n                else:\n                    summary = all_summary\n                ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)\n                v = adaptor(ada_input).to(torch.float32)\n                ret[name] = v\n\n        return ret\n\n    def forward_intermediates(\n            self,\n            x: torch.Tensor,\n            indices: Optional[Union[int, List[int], Tuple[int]]] = None,\n            return_prefix_tokens: bool = False,\n            norm: bool = False,\n            stop_early: bool = False,\n            output_fmt: str = 'NCHW',\n            intermediates_only: bool = False,\n            aggregation: Optional[str] = \"sparse\",\n            norm_alpha_scheme: Optional[str] = \"post-alpha\",\n    ) -> List[RadioOutput]:\n        \"\"\" Forward features that returns intermediates.\n        Args:\n            x: Input image tensor\n            indices: Take last n blocks if int, select matching indices if sequence\n            return_prefix_tokens: Return both prefix and spatial intermediate tokens\n            norm: Apply norm layer to all intermediates\n            stop_early: Stop iterating over blocks when last desired intermediate hit\n            output_fmt: Shape of intermediate feature outputs. 
Options: NCHW, NLC\n            intermediates_only: Only return intermediate features\n            aggregation: intermediate layer aggregation method (sparse or dense).\n                Dense accumulation is done by averaging the features in each group.\n            norm_alpha_scheme: apply alpha before (\"pre-alpha\") or after accumulation (\"post-alpha\"), or don't normalize (\"none\")\n                Only affects dense aggregation\n        Returns:\n            List of RadioOutput objects.\n        \"\"\"\n        x = self.input_conditioner(x)\n        intermediates = self.model.forward_intermediates(\n            x,\n            indices=indices,\n            return_prefix_tokens=return_prefix_tokens,\n            norm=norm,\n            stop_early=stop_early,\n            output_fmt=output_fmt,\n            intermediates_only=intermediates_only,\n            aggregation=aggregation,\n            inter_feature_normalizer=self.inter_feature_normalizer,\n            norm_alpha_scheme=norm_alpha_scheme,\n        )\n\n        if not intermediates_only:\n            final, intermediates = intermediates\n\n        def prepare_summary(summ: Optional[torch.Tensor]):\n            if summ is None:\n                return summ\n            if self.summary_idxs is not None and summ.shape[1] > 1:\n                summ = summ[:, self.summary_idxs]\n            return summ.flatten(1)\n\n        if return_prefix_tokens:\n            radio_outputs = [\n                RadioOutput(prepare_summary(summary), features)\n                for summary, features in intermediates\n            ]\n        else:\n            radio_outputs = intermediates\n\n        if intermediates_only:\n            return radio_outputs\n        else:\n            final = self._extract_final(x, final, feature_fmt=output_fmt)\n            return final, radio_outputs\n\n\n\ndef create_model_from_args(args) -> nn.Module:\n    in_chans = 3\n    if args.in_chans is not None:\n        in_chans = args.in_chans\n  
  elif args.input_size is not None:\n        in_chans = args.input_size[0]\n\n    # Skip weight initialization unless it's explicitly requested.\n    weight_init = args.model_kwargs.pop(\"weight_init\", \"skip\")\n\n    model = create_model(\n        args.model,\n        pretrained=args.pretrained,\n        in_chans=in_chans,\n        num_classes=args.num_classes,\n        drop_rate=args.drop,\n        drop_path_rate=args.drop_path,\n        drop_block_rate=args.drop_block,\n        global_pool=args.gp,\n        bn_momentum=args.bn_momentum,\n        bn_eps=args.bn_eps,\n        scriptable=args.torchscript,\n        checkpoint_path=args.initial_checkpoint,\n        weight_init=weight_init,\n        **args.model_kwargs,\n    )\n\n    if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):\n        model.norm = nn.Identity()\n\n    model.head = nn.Identity()\n\n    if args.cpe_max_size is not None:\n        uq_teachers = set(t['name'] for t in args.teachers)\n        enable_cpe(\n            model,\n            args.cpe_max_size,\n            num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1,\n            register_multiple=getattr(args, 'register_multiple', None),\n            num_registers=getattr(args, 'cpe_num_registers', None),\n            support_packing=args.support_packing,\n        )\n    \n    return model\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vision_transformer_xpos.py",
    "content": "import math\nfrom typing import Final, List, Optional, Tuple, Union\n\n\nfrom einops import rearrange\nfrom timm.models import register_model\nimport torch\nfrom torch import Type, nn\nfrom torch.nn import functional as F\nfrom torch.nn.init import xavier_normal_, xavier_uniform_, zeros_\n\nfrom .forward_intermediates import forward_intermediates\n\n\ndef _get_init_scale(num_encoder_layers: int, num_decoder_layers: int, is_encoder: bool):\n    if num_encoder_layers > 0 and num_decoder_layers == 0:\n        return math.sqrt(math.log(2 * num_encoder_layers))\n    if num_decoder_layers > 0 and num_encoder_layers == 0:\n        return math.sqrt(math.log(2 * num_decoder_layers))\n    if is_encoder:\n        # Both encoders and decoders\n        return math.sqrt(\n            0.33 * math.log(3 * num_decoder_layers) * math.log(2 * num_encoder_layers)\n        )\n\n    return math.sqrt(math.log(3 * num_decoder_layers))\n\n\n# [1,2]    [1,1,2,2]\n# [3,4] -> [3,3,4,4]\n# [5,6]    [5,5,6,6]\ndef duplicate_interleave(m):\n    return m.view(-1, 1).repeat(1, 2).view(m.shape[0], -1)\n\n# 0,1,2,3,4,5,6,7 -> -1,0,-3,2,-5,4,-7,6\ndef rotate_every_two(x):\n    x1 = x[:, :, ::2]\n    x2 = x[:, :, 1::2]\n    x = torch.stack((-x2, x1), dim=-1)\n    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')\\\n\n\nclass XPosEmbedding2D(torch.nn.Module):\n    \"\"\"Implementation of xPos based on RotaryEmbedding from GPT-NeoX.\n    This implementation is designed to operate on queries and keys that are compatible with\n    [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. 
MinGPTAttention format).\n    \"\"\"\n\n    def __init__(\n        self,\n        head_dim: int,\n        base=50000,\n        scale_base=512\n    ):\n        super().__init__()\n        half_dim = head_dim // 2\n        self.half_dim = half_dim\n        inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2).float() / half_dim))\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.head_dim = head_dim\n        self.token_shape_cached = None\n        self.batch_size_cached = None\n        self.cos_cached: torch.Tensor | None = None\n        self.sin_cached: torch.Tensor | None = None\n        self.scale_cached: torch.Tensor | None = None\n        self.scale_base = scale_base\n        self.register_buffer(\"scale\",\n                             (torch.arange(0, half_dim, 2) + 0.4 * half_dim) / (1.4 * half_dim))\n\n    def cos_sin(\n        self,\n        token_shape: Tuple[int, int],\n        device=\"cuda\",\n        dtype=torch.bfloat16,\n    ) -> torch.Tensor:\n        if token_shape != self.token_shape_cached:\n            self.token_shape_cached = token_shape\n            y = torch.arange(token_shape[0], device=device, dtype=self.inv_freq.dtype)\n            x = torch.arange(token_shape[1], device=device, dtype=self.inv_freq.dtype)\n            x, y = torch.meshgrid(x, y, indexing='xy')\n\n            y_freqs = torch.einsum(\"i,j->ij\", y.flatten(), self.inv_freq)\n            x_freqs = torch.einsum(\"i,j->ij\", x.flatten(), self.inv_freq)\n\n            y_scales = self.scale ** y.flatten().div(self.scale_base)[:, None]\n            x_scales = self.scale ** x.flatten().div(self.scale_base)[:, None]\n\n            freqs = torch.cat([y_freqs, x_freqs], dim=-1)\n            emb = torch.repeat_interleave(freqs, repeats=2, dim=-1)\n\n            scales = torch.cat([y_scales, x_scales], dim=-1)\n            scales = torch.repeat_interleave(scales, repeats=2, dim=-1)\n\n            if dtype in [torch.float16, torch.bfloat16]:\n         
       emb = emb.float()\n\n            self.cos_cached = emb.cos()[None, :, :]\n            self.sin_cached = emb.sin()[None, :, :]\n            self.scale_cached = scales[None, :, :]\n\n            self.cos_cached = self.cos_cached.type(dtype)\n            self.sin_cached = self.sin_cached.type(dtype)\n            self.scale_cached = self.scale_cached.type(dtype)\n\n        return self.cos_cached, self.sin_cached, self.scale_cached\n\n    def forward(self, q: torch.Tensor, k: torch.Tensor, token_shape: Tuple[int, int]):\n        batch, seq_len, head_dim = q.shape\n        cos, sin, scale = self.cos_sin(token_shape, q.device, q.dtype)\n        # scale = self.scale**torch.arange(seq_len).to(self.scale).div(self.scale_base)[:, None]\n        # scale = torch.repeat_interleave(scale, 2, dim=-1).to(q.device)\n        # scale = torch.cat([scale, scale], dim=-1)\n        # scale = 1\n        return (\n            (q * cos * scale) + (rotate_every_two(q) * sin * scale),\n            (k * cos * (1 / scale)) + (rotate_every_two(k) * sin * (1 / scale)),\n        )\n\n\nclass MagnetoAttention(nn.Module):\n    def __init__(self, d_model: int, n_head: int, pos_emb: XPosEmbedding2D):\n        super().__init__()\n        self.num_heads = n_head\n        self.head_dim = d_model // n_head\n        self.scale = self.head_dim ** -0.5\n\n        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)\n        self.proj = nn.Linear(d_model, d_model)\n        self.pos_emb = pos_emb\n\n        self.norm0 = nn.LayerNorm(d_model)\n        self.norm1 = nn.LayerNorm(d_model)\n\n    def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:\n        B, N, C = x.shape\n        x = self.norm0(x)\n\n        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)\n        q, k, v = qkv.unbind(0)\n\n        q_pref = q[:, :num_prefix_tokens]\n        q_patch = q[:, num_prefix_tokens:]\n\n        k_pref = k[:, :num_prefix_tokens]\n        k_patch = 
k[:, num_prefix_tokens:]\n\n        q_patch, k_patch = self.pos_emb(q_patch, k_patch, patch_shape)\n\n        q = torch.cat([q_pref, q_patch], dim=1)\n        k = torch.cat([k_pref, k_patch], dim=1)\n\n        def head_reshape(t: torch.Tensor):\n            return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n\n        q = head_reshape(q)\n        k = head_reshape(k)\n        v = head_reshape(v)\n\n        x = F.scaled_dot_product_attention(q, k, v)\n        x = x.transpose(1, 2).reshape(B, N, C)\n        x = self.norm1(x)\n        x = self.proj(x)\n        return x\n\n    def _reset_parameters(self):\n        xavier_uniform_(self.qkv.weight)\n        if self.qkv.bias is not None:\n            zeros_(self.qkv.bias)\n        xavier_normal_(self.proj.weight)\n        zeros_(self.proj.bias)\n\n\nclass MagnetoTransformerEncoderLayer(nn.Module):\n    def __init__(self, d_model: int, nhead: int, pos_emb: XPosEmbedding2D,\n                 num_encoder_layers: int, num_decoder_layers: int = 0,\n                 dim_mhsa: int = 0,\n                 dim_feedforward: int = 2048,\n                 layer_norm_eps: float = 1e-5,\n                 batch_first: bool = True):\n        super().__init__()\n\n        if dim_mhsa == 0:\n            dim_mhsa = d_model\n\n        self._num_encoder_layers = num_encoder_layers\n        self._num_decoder_layers = num_decoder_layers\n\n        self.attn = MagnetoAttention(d_model, nhead, pos_emb)\n\n        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)\n        self.linear2 = nn.Linear(d_model, dim_feedforward)\n        self.norm3 = nn.LayerNorm(dim_feedforward, eps=layer_norm_eps)\n        self.linear3 = nn.Linear(dim_feedforward, d_model)\n\n    def initialize(self):\n        gamma = _get_init_scale(self._num_encoder_layers, self._num_decoder_layers, is_encoder=True)\n\n        # Magneto Initialization\n        for mod in self.children():\n            if isinstance(mod, nn.Linear):\n                
xavier_normal_(mod.weight.data, gamma)\n            elif isinstance(mod, MagnetoAttention):\n                mod._reset_parameters()\n\n    def forward(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:\n        x = x + self._sa_block(x, num_prefix_tokens, patch_shape)\n        x = x + self._ff_block(x)\n        return x\n\n    def _sa_block(self, x: torch.Tensor, num_prefix_tokens: int, patch_shape: Tuple[int, int]) -> torch.Tensor:\n        x = self.attn(x, num_prefix_tokens, patch_shape)\n        return x\n\n    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.norm2(x)\n        x = self.linear2(x)\n        x = F.gelu(x)\n        x = self.norm3(x)\n        x = self.linear3(x)\n        return x\n\n\nclass VisionTransformer(nn.Module):\n    \"\"\" Vision Transformer\n\n    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`\n        - https://arxiv.org/abs/2010.11929\n    \"\"\"\n    dynamic_img_size: Final[bool]\n\n    def __init__(\n            self,\n            patch_size: Union[int, Tuple[int, int]] = 16,\n            in_chans: int = 3,\n            embed_dim: int = 768,\n            depth: int = 12,\n            num_heads: int = 12,\n            mlp_ratio: float = 4.,\n            num_cls_tokens: int = 1,\n            num_reg_tokens: int = 0,\n    ) -> None:\n        \"\"\"\n        Args:\n            patch_size: Patch size.\n            in_chans: Number of image input channels.\n            embed_dim: Transformer embedding dimension.\n            depth: Depth of transformer.\n            num_heads: Number of attention heads.\n            mlp_ratio: Ratio of mlp hidden dim to embedding dim.\n            num_cls_tokens: Number of cls tokens\n            num_reg_tokens: Number of register tokens.\n            block_fn: Transformer block layer.\n        \"\"\"\n        super().__init__()\n\n        self.patch_size = patch_size\n        self.embed_dim = 
embed_dim\n        self.num_cls_tokens = num_cls_tokens\n        self.num_reg_tokens = num_reg_tokens\n\n        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        self.prefix_buffer = nn.Parameter(torch.randn(1, self.num_prefix_tokens, embed_dim) * .02)\n\n        pos_emb = XPosEmbedding2D(embed_dim)\n\n        self.blocks = nn.ModuleList([\n            MagnetoTransformerEncoderLayer(\n                d_model=embed_dim,\n                nhead=num_heads,\n                num_encoder_layers=depth,\n                num_decoder_layers=0,\n                dim_feedforward=int(embed_dim * mlp_ratio),\n                pos_emb=pos_emb,\n            )\n            for _ in range(depth)\n        ])\n\n        for block in self.blocks:\n            block.initialize()\n\n    @property\n    def num_prefix_tokens(self):\n        return self.num_cls_tokens + self.num_reg_tokens\n\n    @property\n    def num_summary_tokens(self):\n        return self.num_prefix_tokens\n\n    def forward_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        x, patch_shape = self._patchify(x)\n\n        for block in self.blocks:\n            x = block(x, self.num_prefix_tokens, patch_shape)\n\n        summary = x[:, :self.num_cls_tokens]\n        features = x[:, self.num_prefix_tokens:]\n\n        return summary, features\n\n    def forward_intermediates(self, x: torch.Tensor, norm: bool = False, **kwargs):\n        patch_shape = tuple(d // self.patch_size for d in x.shape[-2:])\n\n        def patch_extractor(x: torch.Tensor):\n            x, _ = self._patchify(x)\n            return x\n\n        return forward_intermediates(\n            self,\n            patch_extractor=patch_extractor,\n            num_summary_tokens=self.num_prefix_tokens,\n            num_cls_tokens=self.num_cls_tokens,\n            norm=lambda y: y,\n            x=x,\n            block_kwargs=dict(num_prefix_tokens=self.num_prefix_tokens, 
patch_shape=patch_shape),\n            **kwargs,\n        )\n\n    def _patchify(self, x: torch.Tensor):\n        x = self.patch_embed(x)\n        patch_shape = x.shape[-2:]\n        x = rearrange(x, 'b c h w -> b (h w) c')\n\n        prefix = self.prefix_buffer.expand(x.shape[0], -1, -1)\n\n        x = torch.cat([prefix, x], dim=1)\n        return x, patch_shape\n\n\n@register_model\ndef vit_base_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:\n    return VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12,\n                             num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)\n\n\n@register_model\ndef vit_large_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:\n    return VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16,\n                             num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)\n\n\n@register_model\ndef vit_huge_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:\n    return VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16,\n                             num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)\n\n\n@register_model\ndef vit_giant_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:\n    return VisionTransformer(patch_size=16, embed_dim=1408, depth=40, num_heads=16,\n                             num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)\n\n\n@register_model\ndef vit_bigG_patch16_xpos(num_cls_tokens: int = 1, num_reg_tokens: int = 0, **kwargs) -> VisionTransformer:\n    return VisionTransformer(patch_size=16, embed_dim=1664, depth=48, num_heads=16,\n                             num_cls_tokens=num_cls_tokens, num_reg_tokens=num_reg_tokens)"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vit_patch_generator.py",
    "content": "# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport math\nfrom typing import Union, Tuple, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom einops import rearrange\n\nfrom .cls_token import ClsToken\n\ninput_dim_t = Union[int, Tuple[int, int]]\n\ntry:\n    # raise ImportError()\n    from indirect_grid_sample import indirect_grid_sample\nexcept ImportError:\n    indirect_grid_sample = None\n\nclass ViTPatchGenerator(nn.Module):\n    def __init__(self,\n                 patch_size: int,\n                 embed_dim: int,\n                 input_dims: input_dim_t,\n                 abs_pos: bool = True,\n                 normalize_patches: bool = False,\n                 cls_token: bool = False,\n                 max_input_dims: Optional[input_dim_t] = None,\n                 pos_dropout: float = 0.0,\n                 return_pos_enc: bool = False,\n                 num_cls_tokens: int = 1,\n                 register_multiple: Optional[int] = None,\n                 num_registers: Optional[int] = None,\n                 patch_bias: bool = False,\n                 device=None, dtype=None,\n    ):\n        super().__init__()\n\n        if isinstance(input_dims, int):\n            input_dims = (input_dims, input_dims)\n\n        if max_input_dims is None:\n            max_input_dims = input_dims\n        if isinstance(max_input_dims, int):\n            max_input_dims = (max_input_dims, max_input_dims)\n\n        max_input_dims = tuple(\n            int(math.ceil(d / patch_size) * patch_size)\n            for d in max_input_dims\n    
    )\n\n        self.cpe_mode = max_input_dims != input_dims\n        self.pos_dropout = pos_dropout\n        self.return_pos_enc = return_pos_enc\n\n        factory = dict(device=device, dtype=dtype)\n\n        self.patch_size = patch_size\n        self.abs_pos = abs_pos\n        self.embed_dim = embed_dim\n\n        self.num_rows = max_input_dims[0] // patch_size\n        self.num_cols = max_input_dims[1] // patch_size\n        self.input_dims = tuple(d // patch_size for d in input_dims)\n        self.num_patches = self.num_rows * self.num_cols\n        self.max_input_dims = max_input_dims\n\n        self.im_to_patches = Im2Patches(patch_size)\n        self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias, **factory)\n\n        if abs_pos:\n            scale = embed_dim ** -0.5\n            self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)\n\n        self.cls_token = ClsToken(\n            embed_dim,\n            num_tokens=num_cls_tokens,\n            enabled=cls_token,\n            register_multiple=register_multiple,\n            num_registers=num_registers,\n        )\n\n        self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        patches = self.embed_patches(x)\n        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])\n        patches = self.cls_token(patches)\n        patches = self.patch_normalizer(patches)\n        if self.return_pos_enc:\n            return patches, pos_enc\n        return patches\n\n    @property\n    def apply_cls_token(self):\n        return self.cls_token.enabled\n\n    @property\n    def num_cls_tokens(self):\n        return self.cls_token.num_tokens\n\n    @property\n    def num_cls_patches(self):\n        return self.cls_token.num_patches\n\n    @property\n    def num_registers(self):\n        return self.cls_token.num_registers\n\n    @property\n  
  def num_skip(self):\n        return self.num_cls_tokens + self.num_registers\n\n    def no_weight_decay(self):\n        return [\n            'pos_embed',\n        ]\n\n    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):\n        if src_embed.shape != targ_embed.shape:\n            src_size = int(math.sqrt(src_embed.shape[1]))\n\n            assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'\n\n            src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)\n            src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)\n            src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')\n        targ_embed.data.copy_(src_embed)\n\n    def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):\n        if src_proj_weight.shape != targ_proj_weight.shape:\n            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))\n\n            assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'\n\n            src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)\n            src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)\n            src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')\n        targ_proj_weight.data.copy_(src_proj_weight)\n\n    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:\n        patches = self.im_to_patches(x)\n        patches = self.embedder(patches)\n        return patches\n\n    def apply_pos_enc(self,\n                      patches: torch.Tensor,\n                      patch_idxs: Optional[torch.Tensor] = None,\n                      input_size: Optional[Tuple[int, int]] = None,\n    ) -> torch.Tensor:\n       
 if not self.abs_pos:\n            return patches\n\n        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)\n\n        if self.training and self.pos_dropout > 0:\n            keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout\n            pos_enc_drop = torch.where(keeps, pos_enc, 0)\n        else:\n            pos_enc_drop = pos_enc\n\n        return patches + pos_enc_drop, pos_enc\n\n    def get_pos_enc(self,\n                    batch_size: int,\n                    patch_idxs: Optional[torch.Tensor] = None,\n                    input_size: Optional[Tuple[int, int]] = None,\n    ) -> torch.Tensor:\n        if input_size is None:\n            input_dims = self.input_dims\n        else:\n            input_dims = tuple(d // self.patch_size for d in input_size)\n\n        pos_embed = self._get_pos_embeddings(batch_size, input_dims)\n\n        if patch_idxs is None:\n            return pos_embed\n\n        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])\n\n        pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)\n        return pos_embed\n\n\n    def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):\n        if (self.num_rows, self.num_cols) == input_dims:\n            return self.pos_embed\n\n        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)\n\n        def window_select(pos_embed):\n            if input_dims[0] < pos_embed.shape[-2]:\n                pos_embed = pos_embed[..., :input_dims[0], :]\n            if input_dims[1] < pos_embed.shape[-1]:\n                pos_embed = pos_embed[..., :, :input_dims[1]]\n            return pos_embed\n\n        if self.cpe_mode:\n            if self.training:\n                min_scale = math.sqrt(0.1)\n                scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + 
min_scale\n                aspect_min = math.log(3 / 4)\n                aspect_max = -aspect_min\n                aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)\n\n                scale_x = scale * aspect\n                scale_y = scale * (1 / aspect)\n                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)\n\n                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)\n\n                lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)\n                lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])\n\n                lin_xy = torch.stack([lin_x, lin_y], dim=-1)\n\n                grid_xy = lin_xy * scale_xy + pos_xy\n\n                # Convert to [-1, 1] range\n                grid_xy.mul_(2).sub_(1)\n\n                pos_embed = F.grid_sample(\n                    pos_embed.float().expand(batch_size, -1, -1, -1),\n                    grid=grid_xy,\n                    mode='bilinear',\n                    padding_mode='zeros',\n                    align_corners=True,\n                ).to(pos_embed.dtype)\n            else:\n                # i_rows, i_cols = input_dims\n                # p_rows, p_cols = pos_embed.shape[2:]\n                # if i_rows <= p_rows and i_cols <= p_cols:\n                #     left = (p_cols - i_cols) // 2\n                #     top = (p_rows - i_rows) // 2\n                #     pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]\n                # else:\n                max_dim = max(input_dims)\n                pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)\n\n                pos_embed = window_select(pos_embed)\n        else:\n            pos_embed = 
window_select(pos_embed)\n\n        if pos_embed.shape[-2:] != input_dims:\n            pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)\n\n        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)\n\n        return pos_embed\n\n\nclass Im2Patches(nn.Module):\n    def __init__(self, patch_size: int):\n        super().__init__()\n        self.patch_size = patch_size\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.patch_size == 1:\n            patches = x.flatten(2)\n            patches = patches.permute(0, 2, 1)\n            return patches\n\n        py = x.shape[-2] // self.patch_size\n        px = x.shape[-1] // self.patch_size\n        patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',\n                            py=py, yy=self.patch_size,\n                            px=px, xx=self.patch_size,\n        )\n        return patches\n\n\nclass ViTPatchLinear(nn.Linear):\n    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):\n        super().__init__(\n            3 * (patch_size ** 2),\n            embed_dim,\n            bias=bias,\n            **factory\n        )\n        self.patch_size = patch_size\n"
  },
  {
    "path": "nit/models/nvidia_radio/radio/vitdet.py",
    "content": "from collections import defaultdict\nfrom contextlib import contextmanager\nfrom logging import getLogger\nimport math\nimport sys\nfrom typing import List, Union, Iterable\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom timm.models import VisionTransformer\nfrom einops import rearrange\n\nfrom .extra_models import DinoWrapper\n\nDEFAULT_NUM_WINDOWED = 5\nDEFAULT_NUM_GLOBAL = 4\n\n\nclass VitDetArgs:\n    def __init__(self,\n                 window_size: int,\n                 num_summary_tokens: int,\n                 num_windowed: int = None,\n                 num_global: int = None,\n    ):\n        self.window_size = window_size\n        self.num_summary_tokens = num_summary_tokens\n        self.num_windowed = num_windowed\n        self.num_global = num_global\n\n\ndef apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):\n    if isinstance(model, VisionTransformer):\n        patch_embed = getattr(model, 'patch_generator', model.patch_embed)\n\n        return ViTDetHook(patch_embed, model.blocks, args)\n    elif isinstance(model, DinoWrapper):\n        inner = model.inner\n\n        patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)\n        return ViTDetHook(patch_embed, inner.blocks, args)\n    else:\n        print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)\n\n\nclass ViTDetHook:\n    def __init__(self,\n                 embedder: nn.Module,\n                 blocks: nn.Sequential,\n                 args: VitDetArgs,\n    ):\n        self.blocks = blocks\n        self.num_summary_tokens = args.num_summary_tokens\n        self.window_size = args.window_size\n\n        self._input_resolution = None\n        self._num_windows = None\n        self._cls_patch = None\n        self._order_cache = dict()\n\n        embedder.register_forward_pre_hook(self._enter_model)\n\n        # This will decide if we window-fy the patches\n        # and enable vit-det for this iteration, 
and if so,\n        # rearrange the patches for efficient mode switching\n        blocks.register_forward_pre_hook(self._enter_blocks)\n\n        is_global = True\n        if args.num_windowed is not None:\n            period = args.num_windowed + 1\n        else:\n            num_global = args.num_global or DEFAULT_NUM_GLOBAL\n            period = max(len(blocks) // num_global, 1)\n\n        for i, layer in enumerate(blocks[:-1]):\n            ctr = i % period\n            if ctr == 0:\n                layer.register_forward_pre_hook(self._to_windows)\n                is_global = False\n            elif ctr == period - 1:\n                layer.register_forward_pre_hook(self._to_global)\n                is_global = True\n\n        # Always ensure the final layer is a global layer\n        if not is_global:\n            blocks[-1].register_forward_pre_hook(self._to_global)\n\n        blocks.register_forward_hook(self._exit_model)\n\n    def _enter_model(self, _, input: List[torch.Tensor]):\n        self._input_resolution = input[0].shape[-2:]\n\n    def _enter_blocks(self, _, input: List[torch.Tensor]):\n        # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)\n\n        patches = input[0]\n        patches = self._rearrange_patches(patches)\n\n        return (patches,) + input[1:]\n\n    def _to_windows(self, _, input: List[torch.Tensor]):\n        patches = input[0]\n\n        if self.num_summary_tokens:\n            self._cls_patch = patches[:, :self.num_summary_tokens]\n            patches = patches[:, self.num_summary_tokens:]\n\n        patches = rearrange(\n            patches, 'b (p t) c -> (b p) t c',\n            p=self._num_windows, t=self.window_size ** 2,\n        )\n\n        return (patches,) + input[1:]\n\n    def _to_global(self, _, input: List[torch.Tensor]):\n        patches = input[0]\n\n        patches = rearrange(\n            patches, '(b p) t c -> b (p t) c',\n            p=self._num_windows, 
t=self.window_size ** 2,\n            b=patches.shape[0] // self._num_windows,\n        )\n\n        if self.num_summary_tokens:\n            patches = torch.cat([\n                self._cls_patch,\n                patches,\n            ], dim=1)\n\n        return (patches,) + input[1:]\n\n    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):\n        # Return patches to their original order\n        patch_order = self._order_cache[self._input_resolution][0]\n        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)\n\n        ret_patches = torch.empty_like(patches)\n        ret_patches = torch.scatter(\n            ret_patches,\n            dim=1,\n            index=patch_order,\n            src=patches,\n        )\n\n        return ret_patches\n\n    def _rearrange_patches(self, patches: torch.Tensor):\n        # We rearrange the patches so that we can efficiently\n        # switch between windowed and global mode by just\n        # reshaping the tensor\n\n        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))\n        if patch_order is None:\n            num_feat_patches = patches.shape[1] - self.num_summary_tokens\n            num_pixels = self._input_resolution[0] * self._input_resolution[1]\n\n            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))\n            rows = self._input_resolution[-2] // patch_size\n            cols = self._input_resolution[-1] // patch_size\n\n            w_rows = rows // self.window_size\n            w_cols = cols // self.window_size\n\n            patch_order = torch.arange(0, num_feat_patches, device=patches.device)\n\n            patch_order = rearrange(\n                patch_order, '(wy py wx px) -> (wy wx py px)',\n                wy=w_rows, wx=w_cols,\n                py=self.window_size, px=self.window_size,\n            )\n\n            if self.num_summary_tokens:\n                patch_order = torch.cat([\n         
           torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),\n                    patch_order + self.num_summary_tokens,\n                ])\n\n            self._num_windows = w_rows * w_cols\n            self._order_cache[self._input_resolution] = (\n                patch_order,\n                self._num_windows,\n            )\n\n        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)\n        patches = torch.gather(patches, dim=1, index=patch_order)\n        return patches\n"
  },
  {
    "path": "nit/models/utils/convs.py",
    "content": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom nit.models.efficientvit.models.nn.ops import ConvLayer\nfrom nit.models.efficientvit.models.nn.act import build_act\nfrom nit.models.efficientvit.models.utils import val2tuple\n\ndef create_conv_1(conv_type, in_channels, out_channels, norm, act_func, groups=1):\n    '''\n    conv_type: dwconv_3x3_1, dsconv_3x3_1, dgconv_3x3_1\n    '''\n    if conv_type == None or conv_type == \"\":\n        return nn.Identity()\n    splited_conv_type = conv_type.split('_')\n    conv_type = splited_conv_type[0]\n    kernel_size = int(splited_conv_type[1].split('x')[0])\n    stride = int(splited_conv_type[2])\n    if conv_type == 'dwconv':\n        return DWConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func)\n    elif conv_type == 'dsconv':\n        return DSConv(in_channels, out_channels, kernel_size, stride, norm=norm, act_func=act_func)\n    elif conv_type == 'dgconv':\n        return DGConv(in_channels, out_channels, kernel_size, stride, groups, norm=norm, act_func=act_func)\n    else:\n        return nn.Identity()\n\n\n\ndef create_conv_2(conv_type, in_channels, out_channels, mid_channels):\n    '''\n    conv_type: mbconv_3x3_1, fusedmbconv_3x3_1, glumbconv_3x3_1\n    '''\n    if conv_type == None or conv_type == \"\":\n        return nn.Identity()\n    splited_conv_type = conv_type.split('_')\n    conv_type = splited_conv_type[0]\n    kernel_size = int(splited_conv_type[1].split('x')[0])\n    stride = int(splited_conv_type[2])\n    if conv_type == 'mbconv':\n        return MBConv(in_channels, out_channels, kernel_size, stride, mid_channels)\n    elif conv_type == 'fusedmbconv':\n        return FusedMBConv(in_channels, out_channels, kernel_size, stride, mid_channels)\n    elif conv_type == 'glumbconv':\n        return GLUMBConv(in_channels, out_channels, kernel_size, stride, mid_channels)\n    else:\n        return nn.Identity()\n\n\n\nclass 
DWConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        use_bias=True,\n        norm=\"bn2d\",\n        act_func=\"relu6\",\n    ):\n        super(DWConv, self).__init__()\n\n        self.depth_conv = ConvLayer(\n            in_channels,\n            out_channels,\n            kernel_size,\n            stride,\n            groups=in_channels,\n            norm=norm,\n            act_func=act_func,\n            use_bias=use_bias,\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.depth_conv(x)\n        return x\n\n\n\nclass DSConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        use_bias=(True, True),\n        norm=(\"bn2d\", \"bn2d\"),\n        act_func=(\"relu6\", None),\n    ):\n        super(DSConv, self).__init__()\n\n        use_bias = val2tuple(use_bias, 2)\n        norm = val2tuple(norm, 2)\n        act_func = val2tuple(act_func, 2)\n        \n        self.depth_conv = ConvLayer(\n            in_channels,\n            in_channels,\n            kernel_size,\n            stride,\n            groups=in_channels,\n            norm=norm[0],\n            act_func=act_func[0],\n            use_bias=use_bias[0],\n        )\n        self.point_conv = ConvLayer(\n            in_channels,\n            out_channels,\n            1,\n            norm=norm[1],\n            act_func=act_func[1],\n            use_bias=use_bias[1],\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.depth_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\n\nclass DGConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        groups=16,\n        use_bias=(True, True),\n        norm=(\"bn2d\", \"bn2d\"),\n      
  act_func=(\"relu6\", None),\n    ):\n        super(DGConv, self).__init__()\n\n        use_bias = val2tuple(use_bias, 2)\n        norm = val2tuple(norm, 2)\n        act_func = val2tuple(act_func, 2)\n        self.depth_conv = ConvLayer(\n            in_channels,\n            in_channels,\n            kernel_size,\n            stride,\n            groups=in_channels,\n            norm=norm[0],\n            act_func=act_func[0],\n            use_bias=use_bias[0],\n        )\n        self.point_conv = ConvLayer(\n            in_channels,\n            out_channels,\n            1,\n            groups=groups,\n            norm=norm[1],\n            act_func=act_func[1],\n            use_bias=use_bias[1],\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.depth_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\n\nclass MBConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=6,\n        use_bias=True,\n        norm=(\"bn2d\", \"bn2d\", \"bn2d\"),\n        act_func=(\"relu6\", \"relu6\", None),\n    ):\n        super(MBConv, self).__init__()\n\n        use_bias = val2tuple(use_bias, 3)\n        norm = val2tuple(norm, 3)\n        act_func = val2tuple(act_func, 3)\n        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels\n\n        self.inverted_conv = ConvLayer(\n            in_channels,\n            mid_channels,\n            1,\n            stride=1,\n            norm=norm[0],\n            act_func=act_func[0],\n            use_bias=use_bias[0],\n        )\n        self.depth_conv = ConvLayer(\n            mid_channels,\n            mid_channels,\n            kernel_size,\n            stride=stride,\n            groups=mid_channels,\n            norm=norm[1],\n            act_func=act_func[1],\n            use_bias=use_bias[1],\n        
)\n        self.point_conv = ConvLayer(\n            mid_channels,\n            out_channels,\n            1,\n            norm=norm[2],\n            act_func=act_func[2],\n            use_bias=use_bias[2],\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.inverted_conv(x)\n        x = self.depth_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\nclass FusedMBConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=6,\n        groups=1,\n        use_bias=True,\n        norm=(\"bn2d\", \"bn2d\"),\n        act_func=(\"relu6\", None),\n    ):\n        super().__init__()\n        use_bias = val2tuple(use_bias, 2)\n        norm = val2tuple(norm, 2)\n        act_func = val2tuple(act_func, 2)\n\n        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels\n\n        self.spatial_conv = ConvLayer(\n            in_channels,\n            mid_channels,\n            kernel_size,\n            stride,\n            groups=groups,\n            use_bias=use_bias[0],\n            norm=norm[0],\n            act_func=act_func[0],\n        )\n        self.point_conv = ConvLayer(\n            mid_channels,\n            out_channels,\n            1,\n            use_bias=use_bias[1],\n            norm=norm[1],\n            act_func=act_func[1],\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.spatial_conv(x)\n        x = self.point_conv(x)\n        return x\n\n\nclass GLUMBConv(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        kernel_size=3,\n        stride=1,\n        mid_channels=None,\n        expand_ratio=6,\n        use_bias=True,\n        norm=(None, None, \"ln2d\"),\n        act_func=(\"silu\", \"silu\", None),\n    ):\n        super().__init__()\n        
use_bias = val2tuple(use_bias, 3)\n        norm = val2tuple(norm, 3)\n        act_func = val2tuple(act_func, 3)\n\n        mid_channels = round(in_channels * expand_ratio) if mid_channels is None else mid_channels\n\n        self.glu_act = build_act(act_func[1], inplace=False)\n        self.inverted_conv = ConvLayer(\n            in_channels,\n            mid_channels * 2,\n            1,\n            use_bias=use_bias[0],\n            norm=norm[0],\n            act_func=act_func[0],\n        )\n        self.depth_conv = ConvLayer(\n            mid_channels * 2,\n            mid_channels * 2,\n            kernel_size,\n            stride=stride,\n            groups=mid_channels * 2,\n            use_bias=use_bias[1],\n            norm=norm[1],\n            act_func=None,\n        )\n        self.point_conv = ConvLayer(\n            mid_channels,\n            out_channels,\n            1,\n            use_bias=use_bias[2],\n            norm=norm[2],\n            act_func=act_func[2],\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = self.inverted_conv(x)\n        x = self.depth_conv(x)\n\n        x, gate = torch.chunk(x, 2, dim=1)\n        gate = self.glu_act(gate)\n        x = x * gate\n\n        x = self.point_conv(x)\n        return x\n"
  },
  {
    "path": "nit/models/utils/funcs.py",
    "content": "import torch\nfrom torch import Tensor\nfrom typing import List, Tuple\nfrom itertools import chain\n\ndef modulate(x, shift, scale):\n    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)\n\n\ndef get_parameter_dtype(parameter: torch.nn.Module):\n    try:\n        params = tuple(parameter.parameters())\n        if len(params) > 0:\n            return params[0].dtype\n\n        buffers = tuple(parameter.buffers())\n        if len(buffers) > 0:\n            return buffers[0].dtype\n\n    except StopIteration:\n        # For torch.nn.DataParallel compatibility in PyTorch 1.5\n\n        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:\n            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]\n            return tuples\n\n        gen = parameter._named_members(get_members_fn=find_tensor_attributes)\n        first_tuple = next(gen)\n        return first_tuple[1].dtype"
  },
  {
    "path": "nit/models/utils/norms.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport math\n\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\n\nimport triton\nimport triton.language as tl\nimport torch.nn.functional as F\n\n\ndef create_norm(norm_type: str, dim: int, eps: float = 1e-6):\n    \"\"\"\n    Creates the specified normalization layer based on the norm_type.\n\n    Args:\n        norm_type (str): The type of normalization layer to create.\n            Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm\n        dim (int): The dimension of the normalization layer.\n        eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.\n\n    Returns:\n        The created normalization layer.\n\n    Raises:\n        NotImplementedError: If an unknown norm_type is provided.\n    \"\"\"\n    if norm_type == None or norm_type == \"\":\n        return nn.Identity()\n    norm_type = norm_type.lower()  # Normalize to lowercase\n\n    if norm_type == \"layernorm\":\n        return nn.LayerNorm(dim, eps=eps, bias=False)\n    elif norm_type == \"np_layernorm\":\n        return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)\n    elif norm_type == \"np_layernorm_32\":\n        return FP32_Layernorm(dim, eps=eps, elementwise_affine=False, bias=True)\n    elif norm_type == \"layernorm_32\":\n        return FP32_Layernorm(dim, eps=eps, bias=True)\n    elif norm_type == \"rmsnorm\":\n        return RMSNorm(dim, include_weight=True, eps=eps)\n    elif norm_type == \"np_rmsnorm\":\n        return RMSNorm(dim, include_weight=False, eps=1e-6)\n    elif norm_type == \"fused_rmsnorm\":\n        return FusedRMSNorm(dim, eps=1/65536)\n    elif norm_type == \"fused_rmsnorm_32\":\n        return FusedRMSNorm32(dim, eps=1e-6)\n    elif norm_type == 
'none':\n        return nn.Identity()\n    else:\n        return nn.Identity()\n\nclass FP32_Layernorm(nn.LayerNorm):\n    def forward(self, inputs: torch.Tensor) -> torch.Tensor:\n        origin_dtype = inputs.dtype\n        if self.bias == None and self.weight == None:\n            return F.layer_norm(\n                input=inputs.float(), \n                normalized_shape=self.normalized_shape, \n                eps=self.eps\n            ).to(origin_dtype)\n        elif self.bias == None:\n            return F.layer_norm(\n                input=inputs.float(), \n                normalized_shape=self.normalized_shape, \n                weight=self.weight.float(), \n                eps=self.eps\n            ).to(origin_dtype)\n        else:\n            return F.layer_norm(\n                input=inputs.float(), \n                normalized_shape=self.normalized_shape, \n                weight=self.weight.float(), \n                bias=self.bias.float(), \n                eps=self.eps\n            ).to(origin_dtype)\n\nclass FusedRMSNorm(nn.Module):\n    \"\"\"Fused RMS Norm, wraps a fused Triton Kernel\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        eps: float = 1e-6,\n    ):\n        super().__init__()\n        self.eps = eps\n        self.weight = nn.Parameter(torch.ones(dim))\n        self.fused_rms_norm_fn = fused_rms_norm_fn\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"leverages Triton Fused RMS Norm kernel\"\"\"\n        return self.fused_rms_norm_fn(\n            x,\n            self.weight,\n            eps=self.eps,\n        )\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)  # type: ignore\n\nclass FusedRMSNorm32(nn.Module):\n    \"\"\"Fused RMS Norm, wraps a fused Triton Kernel\"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        eps: float = 1e-6,\n    ):\n        super().__init__()\n        self.eps = eps\n        self.weight = 
nn.Parameter(torch.ones(dim))\n        self.fused_rms_norm_fn = fused_rms_norm_fn\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"leverages Triton Fused RMS Norm kernel\"\"\"\n        dtype = x.dtype\n        return self.fused_rms_norm_fn(\n            x.to(torch.float32),\n            self.weight,\n            eps=self.eps,\n        ).to(dtype)\n\n    def reset_parameters(self):\n        torch.nn.init.ones_(self.weight)  # type: ignore\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6, **block_kwargs):\n        \"\"\"\n        Initialize the RMSNorm normalization layer.\n\n        Args:\n            dim (int): The dimension of the input tensor.\n            include_weight: bool: Whether include weight in the normalization\n            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.\n\n        Attributes:\n            eps (float): A small value added to the denominator for numerical stability.\n            weight (nn.Parameter): Learnable scaling parameter.\n\n        \"\"\"\n        super().__init__()\n        self.eps = eps\n        if include_weight:\n            self.weight = nn.Parameter(torch.ones(dim))\n        else:\n            self.weight = None\n\n    def _norm(self, x):\n        \"\"\"\n        Apply the RMSNorm normalization to the input tensor.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The normalized tensor.\n\n        \"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass through the RMSNorm layer.\n\n        Args:\n            x (torch.Tensor): The input tensor.\n\n        Returns:\n            torch.Tensor: The output tensor after applying RMSNorm.\n\n        \"\"\"\n        output = self._norm(x.float()).type_as(x)\n        if self.weight == 
None:
            return output
        else:
            return output * self.weight



# FusedRMSNorm in Triton

# Credit
# Tri Dao's Triton LayerNorm: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py
# Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_fwd_kernel(
    X,
    stride_x,
    Y,
    stride_y,
    W,
    Rstd,
    eps,
    M,  # num rows
    N,  # num cols
    block_N: tl.constexpr,
):
    # One program per row: y = (x / sqrt(mean(x^2) + eps)) * w.
    # Requires block_N >= N (whole row fits in one block); the host wrapper
    # enforces this before launch.
    row = tl.program_id(0)
    cols = tl.arange(0, block_N)

    # Load input data and weights
    mask = cols < N
    x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Compute mean and variance
    # (xbar re-masks the padded tail so the sum only covers the N real columns)
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Store the reciprocal standard deviation
    # (saved per-row for reuse in the backward pass)
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    x_hat = x * rstd
    y = x_hat * w

    # Write output
    tl.store(Y + row * stride_y + cols, y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N"],
)
@triton.jit
def _rms_norm_bwd_kernel_sm(
    X,
    stride_x,
    W,
    DY,
    stride_dy,
    DX,
    stride_dx,
    Rstd,
    DW,
    eps,
    M,  # num rows
    N,  # num cols
    rows_per_program,
    block_N: tl.constexpr,
):
    # Backward pass. Each program handles a contiguous strip of
    # `rows_per_program` rows, writing dX per row and accumulating a partial
    # weight gradient in DW[row_block_id]; the host sums the partials.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    cols = tl.arange(0, block_N)
    mask = cols < N

    # Load weights
    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)

    # Accumulate gradients for weights
    dw = tl.zeros((block_N,), dtype=tl.float32)

    # Clamp so the last program does not run past the final row.
    row_end = min(row_start + rows_per_program, M)
    for row in range(row_start, row_end):
        # Load input, output gradient, and reciprocal standard deviation
        x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
        dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32)
        rstd = tl.load(Rstd + row)

        # Compute normalized input and gradients
        # dx = (w*dy - x_hat * mean(x_hat * w * dy)) * rstd
        x_hat = x * rstd
        wdy = w * dy
        dw += dy * x_hat
        c1 = tl.sum(x_hat * wdy, axis=0) / N
        dx = (wdy - x_hat * c1) * rstd

        # Store input gradient
        tl.store(DX + row * stride_dx + cols, dx, mask=mask)

    # Store weight gradients
    # (one partial row of DW per program; reduced on the host)
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)


class TritonFusedRMSNorm(torch.autograd.Function):
    """Fused RMSNorm (no bias, no mean subtraction) via the Triton kernels above.

    forward: y = x / sqrt(mean(x^2, dim=-1) + eps) * weight
    The per-row reciprocal std is cached for the backward pass.
    """

    @staticmethod
    def forward(ctx, x, weight, eps):
        # Remember the original shape so the output can be restored after
        # normalizing over the flattened (M, N) view.
        x_shape_start = x.shape

        # Flatten input
        x = x.view(-1, x.shape[-1])
        # Kernels assume unit stride along the normalized dimension.
        if x.stride(-1) != 1:
            x = x.contiguous()
        if weight.stride(-1) != 1:
            weight = weight.contiguous()

        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Cap the block at 64 KiB worth of elements (single-block row limit).
        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        # One program per row.
        grid = lambda meta: (M,)
        _rms_norm_fwd_kernel[grid](
            x,
            x.stride(0),
            y,
            y.stride(0),
            weight,
            rstd,
            eps,
            M,
            N,
            block_N,
        )

        ctx.eps = eps
        # NOTE: saves the flattened (and possibly contiguous-copied) x.
        ctx.save_for_backward(x, weight, rstd)
        ctx.x_shape_start = x_shape_start

        y = y.reshape(x_shape_start)
        return y

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        eps = ctx.eps
        x_shape_start = ctx.x_shape_start

        # Flatten input and output gradients
        dy = dy.view(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()

        M, N = dy.shape
        dx = torch.empty_like(x)
        dw = torch.empty_like(weight)

        # Launch one program per SM; each accumulates a partial weight grad.
        sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
        _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)

        max_size = 65536 // x.element_size()
        block_N = min(max_size, triton.next_power_of_2(N))
        rows_per_sm = math.ceil(M / sm_count)

        if N > block_N:
            raise ValueError(f"N {N} must be <= {block_N=}")

        grid = lambda meta: (sm_count,)
        _rms_norm_bwd_kernel_sm[grid](
            x,
            x.stride(0),
            weight,
            dy,
            dy.stride(0),
            dx,
            dx.stride(0),
            rstd,
            _dw,
            eps,
            M,
            N,
            rows_per_sm,
            block_N,
        )
        # Reduce the per-SM partial weight gradients.
        dw = _dw.sum(0).to(weight.dtype)
        dx = dx.view(x_shape_start)
        # Gradients for (x, weight, eps); eps is non-differentiable.
        return dx, dw, None


# expose fusedRMSNorm as a function
def fused_rms_norm_fn(
    x,
    weight,
    eps=1e-6,
):
    # Thin functional wrapper around the autograd Function.
    return TritonFusedRMSNorm.apply(
        x,
        weight,
        eps,
    )
  },
  {
    "path": "nit/models/utils/pos_embeds/flash_attn_rotary.py",
    "content": "# Copyright (c) 2023, Tri Dao.\n\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom einops import rearrange, repeat\nfrom flash_attn.ops.triton.rotary import apply_rotary\n\n\ndef rotate_half(x, interleaved=False):\n    if not interleaved:\n        x1, x2 = x.chunk(2, dim=-1)\n        return torch.cat((-x2, x1), dim=-1)\n    else:\n        x1, x2 = x[..., ::2], x[..., 1::2]\n        return rearrange(torch.stack((-x2, x1), dim=-1), \"... d two -> ... (d two)\", two=2)\n\n\ndef apply_rotary_emb_torch(x, cos, sin, interleaved=False):\n    \"\"\"\n    x: (batch_size, seqlen, nheads, headdim)\n    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)\n    \"\"\"\n    ro_dim = cos.shape[-1] * 2\n    assert ro_dim <= x.shape[-1]\n    cos = repeat(cos, \"... d -> ... 1 (2 d)\" if not interleaved else \"... d -> ... 1 (d 2)\")\n    sin = repeat(sin, \"... d -> ... 1 (2 d)\" if not interleaved else \"... d -> ... 1 (d 2)\")\n    return torch.cat(\n        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],\n        dim=-1,\n    )\n\n\nclass ApplyRotaryEmb(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        x,\n        cos,\n        sin,\n        interleaved=False,\n        inplace=False,\n        seqlen_offsets: Union[int, torch.Tensor] = 0,\n        cu_seqlens: Optional[torch.Tensor] = None,\n        max_seqlen: Optional[int] = None,\n    ):\n        out = apply_rotary(\n            x,\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=max_seqlen,\n            interleaved=interleaved,\n            inplace=inplace,\n        )\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        ctx.inplace = inplace\n        ctx.max_seqlen = max_seqlen\n        return out if not inplace else x\n\n    @staticmethod\n    def backward(ctx, do):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin, cu_seqlens = ctx.saved_tensors\n        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with\n        # \"[CUDA]: invalid device context\", and cloning makes it work. Idk why. Triton 2.1.0 works.\n        if not ctx.interleaved and not ctx.inplace:\n            do = do.clone()\n        dx = apply_rotary(\n            do,\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n            cu_seqlens=cu_seqlens,\n            max_seqlen=ctx.max_seqlen,\n            interleaved=ctx.interleaved,\n            inplace=ctx.inplace,\n            conjugate=True,\n        )\n        return dx, None, None, None, None, None, None, None\n\n\ndef apply_rotary_emb(\n    x,\n    cos,\n    sin,\n    interleaved=False,\n    inplace=False,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[int] = None,\n):\n    \"\"\"\n    Arguments:\n        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None\n            else (total_seqlen, nheads, headdim)\n        cos, sin: (seqlen_rotary, rotary_dim / 2)\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n            of 1st half and 2nd half (GPT-NeoX style).\n        inplace: if True, apply rotary embedding in-place.\n        seqlen_offsets: (batch_size,) or int. 
Each sequence in x is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n        cu_seqlens: (batch + 1,) or None\n        max_seqlen: int\n    Return:\n        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None\n            else (total_seqlen, nheads, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding to the first rotary_dim of x.\n    \"\"\"\n    return ApplyRotaryEmb.apply(\n        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen\n    )\n\n\n# For backward compatibility\napply_rotary_emb_func = apply_rotary_emb\n\n#TODO need check ,whlzy modified!!!!\nclass ApplyRotaryEmbQKV_(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx,\n        qkv,\n        cos,\n        sin,\n        cos_k=None,\n        sin_k=None,\n        interleaved=False,\n        seqlen_offsets: Union[int, torch.Tensor] = 0,\n        cu_seqlens: Optional[torch.Tensor] = None,\n        max_seqlen: Optional[int] = None,\n    ):\n        total, three, nheads, headdim = qkv.shape # (total, 3, nheads, headdim)\n        assert three == 3\n        if cos_k is None and sin_k is None and qkv.is_contiguous():\n            # Call 1 kernel instead of 2 kernels\n            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)\n            # dimensions, we get the same tensor\n            # qk = rearrange(qkv[:, :, :2], \"b s t h d -> b s (t h) d\")\n            qk = qkv[:, :2].reshape(total, -1, headdim)\n            apply_rotary(\n                qk, \n                cos, \n                sin, \n                seqlen_offsets=seqlen_offsets,\n                cu_seqlens=cu_seqlens,\n                max_seqlen=max_seqlen, \n                interleaved=interleaved, \n                inplace=True\n            )\n        else:\n            cos_k = cos if cos_k is None else cos_k\n            sin_k = sin if sin_k is None else sin_k\n            q, k = qkv[:, 0], 
qkv[:, 1]\n            apply_rotary(\n                q, \n                cos, \n                sin, \n                seqlen_offsets=seqlen_offsets, \n                cu_seqlens=cu_seqlens, \n                max_seqlen=max_seqlen, \n                interleaved=interleaved, \n                inplace=True\n            )\n            apply_rotary(\n                k, \n                cos_k, \n                sin_k, \n                seqlen_offsets=seqlen_offsets, \n                cu_seqlens=cu_seqlens, \n                max_seqlen=max_seqlen, \n                interleaved=interleaved, \n                inplace=True\n            )\n            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens)\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            ctx.save_for_backward(cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        ctx.max_seqlen = max_seqlen\n        return qkv\n\n    @staticmethod\n    def backward(ctx, dqkv):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin, cos_k, sin_k, cu_seqlens = ctx.saved_tensors\n        if cos_k is None and sin_k is None and dqkv.is_contiguous():\n            # Call 1 kernel instead of 2 kernels\n            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)\n            # dimensions, we get the same tensor\n            dqk = rearrange(dqkv[:, :, :2], \"b t h d -> b (t h) d\") # b for total\n            apply_rotary(\n                dqk,\n                cos,\n                sin,\n                seqlen_offsets=seqlen_offsets,\n                cu_seqlens=cu_seqlens,\n                max_seqlen=ctx.max_seqlen,\n               
 interleaved=ctx.interleaved,\n                inplace=True,\n                conjugate=True,\n            )\n        else:\n            cos_k = cos if cos_k is None else cos_k\n            sin_k = sin if sin_k is None else sin_k\n            dq, dk = dqkv[:, 0], dqkv[:, 1]\n            apply_rotary(\n                dq, \n                cos, \n                sin, \n                seqlen_offsets=seqlen_offsets,\n                cu_seqlens=cu_seqlens,\n                max_seqlen=ctx.max_seqlen, \n                interleaved=ctx.interleaved, \n                inplace=True, \n                conjugate=True\n            )\n            apply_rotary(\n                dk,\n                cos_k,\n                sin_k,\n                seqlen_offsets=seqlen_offsets,\n                cu_seqlens=cu_seqlens,\n                max_seqlen=ctx.max_seqlen,\n                interleaved=ctx.interleaved,\n                inplace=True,\n                conjugate=True,\n            )\n        return dqkv, None, None, None, None, None, None, None, None\n\n#TODO need check ,whlzy modified!!!!\ndef apply_rotary_emb_qkv_(\n    qkv,\n    cos,\n    sin,\n    cos_k=None,\n    sin_k=None,\n    interleaved=False,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n    cu_seqlens: Optional[torch.Tensor] = None,\n    max_seqlen: Optional[int] = None,\n):\n    \"\"\"\n    Arguments:\n        qkv: (batch_size, seqlen, 3, nheads, headdim)\n        cos, sin: (seqlen, rotary_dim / 2)\n        cos_k, sin_k: (seqlen, rotary_dim / 2), optional\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of\n            1st half and 2nd half (GPT-NeoX style).\n        seqlen_offsets: (batch_size,) or int. 
Each sequence in Q and K is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n    Return:\n        qkv: (batch_size, seqlen, 3, nheads, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.\n    \"\"\"\n    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen)\n\n\nclass ApplyRotaryEmbKV_(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):\n        batch, seqlen, two, nheads, headdim = kv.shape\n        assert two == 2\n        k = kv[:, :, 0]\n        apply_rotary(\n            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True\n        )\n        if isinstance(seqlen_offsets, int):\n            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward\n            ctx.seqlen_offsets = seqlen_offsets\n        else:\n            ctx.save_for_backward(cos, sin, seqlen_offsets)\n            ctx.seqlen_offsets = None\n        ctx.interleaved = interleaved\n        return kv\n\n    @staticmethod\n    def backward(ctx, dkv):\n        seqlen_offsets = ctx.seqlen_offsets\n        if seqlen_offsets is None:\n            cos, sin, seqlen_offsets = ctx.saved_tensors\n        else:\n            cos, sin = ctx.saved_tensors\n        apply_rotary(\n            dkv[:, :, 0],\n            cos,\n            sin,\n            seqlen_offsets=seqlen_offsets,\n            interleaved=ctx.interleaved,\n            inplace=True,\n            conjugate=True,\n        )\n        return dkv, None, None, None, None\n\n\napply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply\n\n\ndef apply_rotary_emb_kv_(\n    kv,\n    cos,\n    sin,\n    interleaved=False,\n    seqlen_offsets: Union[int, torch.Tensor] = 0,\n):\n    \"\"\"\n    Arguments:\n        kv: (batch_size, seqlen, 2, nheads, headdim)\n     
   cos, sin: (seqlen, rotary_dim / 2)\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of\n            1st half and 2nd half (GPT-NeoX style).\n        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n    Return:\n        kv: (batch_size, seqlen, 2, nheads, headdim)\n    rotary_dim must be <= headdim\n    Apply rotary embedding *inplace* to the first rotary_dim of K.\n    \"\"\"\n    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)\n\n\nclass RotaryEmbedding(torch.nn.Module):\n    \"\"\"\n    The rotary position embeddings from RoFormer_ (Su et. al).\n    A crucial insight from the method is that the query and keys are\n    transformed by rotation matrices which depend on the relative positions.\n\n    Other implementations are available in the Rotary Transformer repo_ and in\n    GPT-NeoX_, GPT-NeoX was an inspiration\n\n    .. _RoFormer: https://arxiv.org/abs/2104.09864\n    .. _repo: https://github.com/ZhuiyiTechnology/roformer\n    .. 
_GPT-NeoX: https://github.com/EleutherAI/gpt-neox\n\n    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).\n    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96\n    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        base=10000.0,\n        interleaved=False,\n        scale_base=None,\n        pos_idx_in_fp32=True,\n        device=None,\n    ):\n        \"\"\"\n        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead\n            of 1st half and 2nd half (GPT-NeoX style).\n        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,\n            otherwise they might be in lower precision.\n            This option was added because previously (before 2023-07-02), when we construct\n            the position indices, we use the dtype of self.inv_freq. In most cases this would\n            be fp32, but if the model is trained in pure bf16 (not mixed precision), then\n            self.inv_freq would be bf16, and the position indices are also in bf16.\n            Because of the limited precision of bf16 (e.g. 
1995.0 is rounded to 2000.0), the\n            embeddings for some positions will coincide.\n            To maintain compatibility with models previously trained in pure bf16,\n            we add this option.\n        \"\"\"\n        super().__init__()\n        self.dim = dim\n        self.base = float(base)\n        self.pos_idx_in_fp32 = pos_idx_in_fp32\n        # Generate and save the inverse frequency buffer (non trainable)\n        inv_freq = self._compute_inv_freq(device)\n        self.register_buffer(\"inv_freq\", inv_freq, persistent=False)\n        self.interleaved = interleaved\n        self.scale_base = scale_base\n        scale = (\n            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)\n            if scale_base is not None\n            else None\n        )\n        self.register_buffer(\"scale\", scale, persistent=False)\n\n        self._seq_len_cached = 0\n        self._cos_cached = None\n        self._sin_cached = None\n        self._cos_k_cached = None\n        self._sin_k_cached = None\n\n    def _compute_inv_freq(self, device=None):\n        return 1.0 / (\n            self.base\n            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)\n        )\n\n    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):\n        # Reset the tables if the sequence length has changed,\n        # if we're on a new device (possibly due to tracing for instance),\n        # or if we're switching from inference mode to training\n        if (\n            seqlen > self._seq_len_cached\n            or self._cos_cached is None\n            or self._cos_cached.device != device\n            or self._cos_cached.dtype != dtype\n            or (self.training and self._cos_cached.is_inference())\n        ):\n            self._seq_len_cached = seqlen\n            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16\n            # And the output of 
arange can be quite large, so bf16 would lose a lot of precision.\n            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.\n            if self.pos_idx_in_fp32:\n                t = torch.arange(seqlen, device=device, dtype=torch.float32)\n                # We want fp32 here as well since inv_freq will be multiplied with t, and the output\n                # will be large. Having it in bf16 will lose a lot of precision and cause the\n                # cos & sin output to change significantly.\n                # We want to recompute self.inv_freq if it was not loaded in fp32\n                if self.inv_freq.dtype != torch.float32:\n                    inv_freq = self._compute_inv_freq(device=device)\n                else:\n                    inv_freq = self.inv_freq\n            else:\n                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)\n                inv_freq = self.inv_freq\n            # Don't do einsum, it converts fp32 to fp16 under AMP\n            # freqs = torch.einsum(\"i,j->ij\", t, self.inv_freq)\n            freqs = torch.outer(t, inv_freq)\n            if self.scale is None:\n                self._cos_cached = torch.cos(freqs).to(dtype)\n                self._sin_cached = torch.sin(freqs).to(dtype)\n            else:\n                power = (\n                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)\n                    - seqlen // 2\n                ) / self.scale_base\n                scale = self.scale.to(device=power.device) ** rearrange(power, \"s -> s 1\")\n                # We want the multiplication by scale to happen in fp32\n                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)\n                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)\n                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)\n                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)\n\n    def 
forward(\n        self,\n        qkv: torch.Tensor,\n        kv: Optional[torch.Tensor] = None,\n        seqlen_offset: Union[int, torch.Tensor] = 0,\n        max_seqlen: Optional[int] = None,\n    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n        \"\"\"\n        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,\n             else it's just q of shape (batch, seqlen, nheads, headdim)\n        kv: (batch, seqlen, 2, nheads, headdim)\n        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.\n            Most commonly used in inference when we have KV cache.\n            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one\n            should pass in max_seqlen, which will update the cos / sin cache up to that length.\n        Apply rotary embedding *inplace* to qkv and / or kv.\n        \"\"\"\n        seqlen = qkv.shape[1]\n        if max_seqlen is not None:\n            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)\n        elif isinstance(seqlen_offset, int):\n            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)\n        if kv is None:\n            if self.scale is None:\n                return apply_rotary_emb_qkv_(\n                    qkv,\n                    self._cos_cached,\n                    self._sin_cached,\n                    interleaved=self.interleaved,\n                    seqlen_offsets=seqlen_offset,\n                )\n            else:\n                return apply_rotary_emb_qkv_(\n                    qkv,\n                    self._cos_cached,\n                    self._sin_cached,\n                    self._cos_k_cached,\n                    self._sin_k_cached,\n                    interleaved=self.interleaved,\n                    seqlen_offsets=seqlen_offset,\n                )\n        else:\n            q = qkv\n            q = apply_rotary_emb_func(\n                q,\n      
          self._cos_cached,\n                self._sin_cached,\n                interleaved=self.interleaved,\n                inplace=True,\n                seqlen_offsets=seqlen_offset,\n            )\n            if self.scale is None:\n                kv = apply_rotary_emb_kv_(\n                    kv,\n                    self._cos_cached,\n                    self._sin_cached,\n                    interleaved=self.interleaved,\n                    seqlen_offsets=seqlen_offset,\n                )\n            else:\n                kv = apply_rotary_emb_kv_(\n                    kv,\n                    self._cos_k_cached,\n                    self._sin_k_cached,\n                    interleaved=self.interleaved,\n                    seqlen_offsets=seqlen_offset,\n                )\n            return q, kv"
  },
  {
    "path": "nit/models/utils/pos_embeds/rope.py",
    "content": "# --------------------------------------------------------\n# FiT: A Flexible Vision Transformer for Image Generation\n#\n# Based on the following repository\n# https://github.com/lucidrains/rotary-embedding-torch\n# https://github.com/jquesnelle/yarn/blob/HEAD/scaled_rope\n# https://colab.research.google.com/drive/1VI2nhlyKvd5cw4-zHvAIk00cAVj2lCCC#scrollTo=b80b3f37\n# --------------------------------------------------------\n\nimport math\nfrom math import pi\nfrom typing import Optional, Any, Union, Tuple\nimport torch\nfrom torch import nn\n\nfrom einops import rearrange, repeat\nfrom functools import lru_cache\n\n#################################################################################\n#                                 NTK Operations                                #\n#################################################################################\n\ndef find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):\n    return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base)) #Inverse dim formula to find number of rotations\n\ndef find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):\n    low = math.floor(find_correction_factor(low_rot, dim, base, max_position_embeddings))\n    high = math.ceil(find_correction_factor(high_rot, dim, base, max_position_embeddings))\n    return max(low, 0), min(high, dim-1) #Clamp values just in case\n\ndef linear_ramp_mask(min, max, dim):\n    if min == max:\n        max += 0.001 #Prevent singularity\n\n    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)\n    ramp_func = torch.clamp(linear_func, 0, 1)\n    return ramp_func\n\ndef find_newbase_ntk(dim, base=10000, scale=1):\n    # Base change formula\n    return base * scale ** (dim / (dim-2))\n\ndef get_mscale(scale=torch.Tensor):\n    # if scale <= 1:\n    #     return 1.0\n    # return 0.1 * math.log(scale) + 1.0\n    return 
torch.where(scale <= 1., torch.tensor(1.0), 0.1 * torch.log(scale) + 1.0)\n\ndef get_proportion(L_test, L_train):\n    L_test = L_test * 2\n    return torch.where(torch.tensor(L_test/L_train) <= 1., torch.tensor(1.0), torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train))))\n    # return torch.sqrt(torch.log(torch.tensor(L_test))/torch.log(torch.tensor(L_train)))\n\n\n\n#################################################################################\n#                                 Rotate Q or K                                 #\n#################################################################################\n\ndef rotate_half(x):\n    x = rearrange(x, '... (d r) -> ... d r', r = 2)\n    x1, x2 = x.unbind(dim = -1)\n    x = torch.stack((-x2, x1), dim = -1)\n    return rearrange(x, '... d r -> ... (d r)')\n\n\n\n#################################################################################\n#                               Core Vision RoPE                                #\n#################################################################################\n\nclass VisionRotaryEmbedding(nn.Module):\n    def __init__(\n        self,\n        head_dim: int,  # embed dimension for each head\n        custom_freqs: str = 'normal',\n        theta: int = 10000,\n        online_rope: bool = False,\n        max_cached_len: int = 1024,\n        max_pe_len_h: Optional[int] = None,\n        max_pe_len_w: Optional[int] = None,\n        decouple: bool = False,\n        ori_max_pe_len: Optional[int] = None,\n    ):\n        super().__init__()\n        \n        dim = head_dim // 2\n        assert dim % 2 == 0 # accually, this is important\n        self.dim = dim\n        self.custom_freqs = custom_freqs.lower()\n        self.theta = theta\n        self.decouple = decouple\n        self.ori_max_pe_len = ori_max_pe_len\n        \n        self.custom_freqs = custom_freqs.lower()\n        if not online_rope:\n            if self.custom_freqs in ['normal', 
'scale1', 'scale2']:\n                freqs_h = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))\n                freqs_w = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))\n            else:\n                if decouple:\n                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len_h, ori_max_pe_len)\n                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len_w, ori_max_pe_len)\n                else:\n                    max_pe_len = max(max_pe_len_h, max_pe_len_w)\n                    freqs_h = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)\n                    freqs_w = self.get_1d_rope_freqs(theta, dim, max_pe_len, ori_max_pe_len)\n            \n            self.register_buffer('freqs_h', freqs_h, persistent=False)        \n            self.register_buffer('freqs_w', freqs_w, persistent=False)        \n            \n            if max_pe_len_h != None and max_pe_len_w != None and ori_max_pe_len != None:\n                attn_factor = 1.0\n                scale = torch.clamp_min(torch.tensor(max(max_pe_len_h, max_pe_len_w)) / ori_max_pe_len, 1.0)   # dynamic scale\n                self.mscale = get_mscale(scale).to(scale) * attn_factor # Get n-d magnitude scaling corrected for interpolation\n                self.proportion1 = get_proportion(max(max_pe_len_h, max_pe_len_w), ori_max_pe_len)\n                self.proportion2 = get_proportion(max_pe_len_h * max_pe_len_w, ori_max_pe_len ** 2)\n                \n                \n            freqs_h_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_h)\n            freqs_h_cached = repeat(freqs_h_cached, '... n -> ... (n r)', r = 2)\n            self.register_buffer('freqs_h_cached', freqs_h_cached, persistent=False) \n            freqs_w_cached = torch.einsum('..., f -> ... f', torch.arange(max_cached_len), self.freqs_w)\n            freqs_w_cached = repeat(freqs_w_cached, '... n -> ... 
(n r)', r = 2)\n            self.register_buffer('freqs_w_cached', freqs_w_cached, persistent=False) \n        \n\n    def get_1d_rope_freqs(self, theta, dim, max_pe_len, ori_max_pe_len):\n        # scaling operations for extrapolation\n        assert isinstance(ori_max_pe_len, int)\n        # scale = max_pe_len / ori_max_pe_len\n        if not isinstance(max_pe_len, torch.Tensor):\n            max_pe_len = torch.tensor(max_pe_len)\n        scale = torch.clamp_min(max_pe_len / ori_max_pe_len, 1.0)   # dynamic scale\n        \n        if self.custom_freqs == 'linear': # equal to position interpolation\n            freqs = 1. / torch.einsum('..., f -> ... f', scale, theta ** (torch.arange(0, dim, 2).float() / dim))\n        elif self.custom_freqs == 'ntk-aware' or self.custom_freqs == 'ntk-aware-pro1' or self.custom_freqs == 'ntk-aware-pro2':\n            freqs = 1. / torch.pow(\n                find_newbase_ntk(dim, theta, scale).view(-1, 1), \n                (torch.arange(0, dim, 2).to(scale).float() / dim)\n            ).squeeze()\n        elif self.custom_freqs == 'ntk-by-parts':\n            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)\n            #Do not change unless there is a good reason for doing so!\n            beta_0 = 1.25\n            beta_1 = 0.75\n            gamma_0 = 16\n            gamma_1 = 2\n            ntk_factor = 1\n            extrapolation_factor = 1\n\n            #Three RoPE extrapolation/interpolation methods\n            freqs_base = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))\n            freqs_linear = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))\n            freqs_ntk = 1. 
/ torch.pow(\n                find_newbase_ntk(dim, theta, scale).view(-1, 1), \n                (torch.arange(0, dim, 2).to(scale).float() / dim)\n            ).squeeze()\n            \n            #Combine NTK and Linear\n            low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len)\n            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * ntk_factor\n            freqs = freqs_linear * (1 - freqs_mask) + freqs_ntk * freqs_mask\n            \n            #Combine Extrapolation and NTK and Linear\n            low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len)\n            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale)) * extrapolation_factor\n            freqs = freqs * (1 - freqs_mask) + freqs_base * freqs_mask\n            \n        elif self.custom_freqs == 'yarn':\n            #Interpolation constants found experimentally for LLaMA (might not be totally optimal though)\n            #Do not change unless there is a good reason for doing so!\n            beta_fast = 32\n            beta_slow = 1\n            extrapolation_factor = 1\n            \n            freqs_extrapolation = 1.0 / (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim))\n            freqs_interpolation = 1.0 / torch.einsum('..., f -> ... f', scale, (theta ** (torch.arange(0, dim, 2).to(scale).float() / dim)))\n\n            low, high = find_correction_range(beta_fast, beta_slow, dim, theta, ori_max_pe_len)\n            freqs_mask = (1 - linear_ramp_mask(low, high, dim // 2).to(scale).float()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation\n            freqs = freqs_interpolation * (1 - freqs_mask) + freqs_extrapolation * freqs_mask            \n        else:\n            raise ValueError(f'Unknown modality {self.custom_freqs}. 
Only support normal, linear, ntk-aware, ntk-by-parts, yarn!')\n        return freqs\n\n\n    def online_get_2d_rope_from_grid(self, grid, size):\n        '''\n        grid: (B, 2, N)\n            N = H * W\n            the first dimension represents width, and the second reprensents height\n            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]\n        size: (B, 1, 2), h goes first and w goes last\n        '''\n        size = size.squeeze()   # (B, 1, 2) -> (B, 2)\n        if self.decouple:\n            size_h = size[:, 0]\n            size_w = size[:, 1]\n            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_h, self.ori_max_pe_len)\n            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_w, self.ori_max_pe_len)\n        else:\n            size_max = torch.max(size[:, 0], size[:, 1])\n            freqs_h = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)\n            freqs_w = self.get_1d_rope_freqs(self.theta, self.dim, size_max, self.ori_max_pe_len)\n        freqs_w = grid[:, 0][..., None] * freqs_w[:, None, :]\n        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)\n        \n        freqs_h = grid[:, 1][..., None] * freqs_h[:, None, :]\n        freqs_h = repeat(freqs_h, '... n -> ... 
(n r)', r = 2)\n        \n        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)\n        \n        if self.custom_freqs == 'yarn':\n            freqs_cos = freqs.cos() * self.mscale[:, None, None]\n            freqs_sin = freqs.sin() * self.mscale[:, None, None]\n        elif self.custom_freqs == 'ntk-aware-pro1':\n            freqs_cos = freqs.cos() * self.proportion1[:, None, None]\n            freqs_sin = freqs.sin() * self.proportion1[:, None, None]\n        elif self.custom_freqs == 'ntk-aware-pro2':\n            freqs_cos = freqs.cos() * self.proportion2[:, None, None]\n            freqs_sin = freqs.sin() * self.proportion2[:, None, None]\n        else:\n            freqs_cos = freqs.cos()\n            freqs_sin = freqs.sin()\n            \n        return freqs_cos, freqs_sin  \n\n    @lru_cache()\n    def get_2d_rope_from_grid(self, grid):\n        '''\n        grid: (B, 2, N)\n            N = H * W\n            the first dimension represents width, and the second reprensents height\n            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]\n        '''  \n        freqs_h = torch.einsum('..., f -> ... f', grid[:, 0], self.freqs_h)\n        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)\n        freqs_w = torch.einsum('..., f -> ... f', grid[:, 1], self.freqs_w)\n        freqs_w = repeat(freqs_w, '... n -> ... 
(n r)', r = 2)\n        \n        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)\n        \n        if self.custom_freqs == 'yarn':\n            freqs_cos = freqs.cos() * self.mscale\n            freqs_sin = freqs.sin() * self.mscale\n        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:\n            freqs_cos = freqs.cos() * self.proportion1\n            freqs_sin = freqs.sin() * self.proportion1\n        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:\n            freqs_cos = freqs.cos() * self.proportion2\n            freqs_sin = freqs.sin() * self.proportion2\n        else:\n            freqs_cos = freqs.cos()\n            freqs_sin = freqs.sin()\n\n        return freqs_cos, freqs_sin\n    \n    @lru_cache()\n    def get_cached_2d_rope_from_grid(self, grid: torch.Tensor):\n        '''\n        grid: (B, 2, N)\n            N = H * W\n            the first dimension represents width, and the second reprensents height\n            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 
2.]\n        '''  \n        if len(grid.shape) == 3:    # (B, 2, N)\n            freqs_h, freqs_w = self.freqs_h_cached[grid[:, 0]], self.freqs_w_cached[grid[:, 1]]\n        elif len(grid.shape) == 2:  # (2, N)\n            freqs_h, freqs_w = self.freqs_h_cached[grid[0]], self.freqs_w_cached[grid[1]]\n        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)\n        \n        if self.custom_freqs == 'yarn':\n            freqs_cos = freqs.cos() * self.mscale\n            freqs_sin = freqs.sin() * self.mscale\n        elif self.custom_freqs in ['ntk-aware-pro1', 'scale1']:\n            freqs_cos = freqs.cos() * self.proportion1\n            freqs_sin = freqs.sin() * self.proportion1\n        elif self.custom_freqs in ['ntk-aware-pro2', 'scale2']:\n            freqs_cos = freqs.cos() * self.proportion2\n            freqs_sin = freqs.sin() * self.proportion2\n        else:\n            freqs_cos = freqs.cos()\n            freqs_sin = freqs.sin()\n        \n        return freqs_cos, freqs_sin\n\n    @lru_cache()\n    def get_cached_21d_rope_from_grid(self, grid: torch.Tensor): # for 3d rope formulation 2 !\n        '''\n        grid: (B, 3, N)\n            N = H * W * T\n            the first dimension represents width, and the second reprensents height, and the third reprensents time\n            e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                    [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 2.]\n                    [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n        '''   \n        freqs_w, freqs_h = self.freqs_w_cached[grid[:, 0]+grid[:, 2]], self.freqs_h_cached[grid[:, 1]+grid[:, 2]]\n        freqs = torch.cat([freqs_h, freqs_w], dim=-1)   # (B, N, D)\n        \n        if self.custom_freqs == 'yarn':\n            freqs_cos = freqs.cos() * self.mscale\n            freqs_sin = freqs.sin() * self.mscale\n        elif self.custom_freqs == 'ntk-aware-pro1':\n            freqs_cos = freqs.cos() * self.proportion1\n            freqs_sin = freqs.sin() * self.proportion1\n        elif self.custom_freqs == 'ntk-aware-pro2':\n            freqs_cos = freqs.cos() * self.proportion2\n            freqs_sin = freqs.sin() * self.proportion2\n        else:\n            freqs_cos = freqs.cos()\n            freqs_sin = freqs.sin()\n        \n        return freqs_cos, freqs_sin\n\n    def forward(self, x, grid): \n        '''\n        x: (B, n_head, N, D)\n        grid: (B, 2, N)\n        '''\n        # freqs_cos, freqs_sin = self.get_2d_rope_from_grid(grid)\n        # freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)\n        # using cache to accelerate, this is the same with the above codes:\n        freqs_cos, freqs_sin = self.get_cached_2d_rope_from_grid(grid)\n        freqs_cos, freqs_sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)\n        return  x * freqs_cos + rotate_half(x) * freqs_sin\n    \n "
  },
  {
    "path": "nit/models/utils/pos_embeds/sincos.py",
    "content": "#################################################################################\n#                   Sine/Cosine Positional Embedding Functions                  #\n#################################################################################\n# modified from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py\n\nimport torch\nimport numpy as np\nfrom einops import rearrange\nimport torch.nn.functional as F\n\n\n\ndef get_2d_sincos_pos_embed(embed_dim, h, w, frac_coord_size=None, scale_ratio=1.0, cls_token=False, extra_tokens=0):\n    \"\"\"\n    args:\n        h / w: int of the grid height / width\n        frac_coord_size: \n            if frac_coord_size != None: \n                fractional coordinates for positional embedding is used\n            else: \n                absolute coordinates for positional embedding is used\n    return:\n        pos_embed: [h*w, embed_dim] or [1+h*w, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = torch.arange(h, dtype=torch.float32)\n    grid_w = torch.arange(w, dtype=torch.float32)\n    grid = torch.meshgrid(grid_w, grid_h, indexing='xy')    # here w goes first\n    grid = torch.stack(grid, dim=0)\n    grid = rearrange(grid, '... -> 1 ...')  # (1, 2, h*w)\n    \n    pos_embed = get_2d_sincos_pos_embed_from_grid(\n        grid, embed_dim, frac_coord_size, scale_ratio\n    )  # 1, L, D\n    if cls_token and extra_tokens > 0:\n        pos_embed = torch.cat([torch.zeros((1, extra_tokens, embed_dim)), pos_embed], dim=1)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):\n    '''\n    grid: (B, 2, N)\n        N = H * W\n        the first dimension represents width, and the second reprensents height\n        e.g.,   [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 
2.]\n    frac_coord_size: \n        if frac_coord_size != None: \n            fractional coordinates for positional embedding is used\n        else: \n            absolute coordinates for positional embedding is used\n    '''\n    assert embed_dim % 2 == 0\n    grid = grid.float()\n    if frac_coord_size != None:\n        assert isinstance(frac_coord_size, (int, float))\n        grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size\n        grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size\n    else:\n        grid_w, grid_h = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio\n    # use half of dimensions to encode grid_h\n    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, embed_dim // 2)  # (B, N, D/2)\n    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, embed_dim // 2)  # (B, N, D/2)\n\n    emb = torch.cat([emb_h, emb_w], dim=-1) # (B, L, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(pos, embed_dim):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a batch of list whose positions to be encoded: size (B, N)\n    out: (B, N, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = torch.arange(embed_dim // 2, dtype=torch.float64)\n    omega /= embed_dim / 2.\n    omega = 1. / 10000**omega  # (D/2,)\n\n    out = torch.einsum('BL,D->BLD', pos, omega.to(pos)) # (B, N, D/2), outer product\n\n    emb_sin = torch.sin(out) # (B, N, D/2)\n    emb_cos = torch.cos(out) # (B, N, D/2)\n\n    emb = torch.cat([emb_sin, emb_cos], dim=-1)  # (B, N, D)\n    return emb\n\ndef get_3d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0, time_dim=0):\n    '''\n    grid: (B, 3, N)\n        N = H * W\n        the first dimension represents width, and the second reprensents height\n        e.g.,   [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n                [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 
2.]\n    frac_coord_size: \n        if frac_coord_size != None: \n            fractional coordinates for positional embedding is used\n        else: \n            absolute coordinates for positional embedding is used\n    '''\n    # assert embed_dim % 2 == 0\n    if time_dim == 0:\n        assert embed_dim % 3 == 0\n        dim = embed_dim // 3\n        time_dim = dim\n    else:\n        assert (embed_dim - time_dim) % 2 == 0\n        dim = (embed_dim - time_dim) // 2\n    \n    grid = grid.float()\n    if frac_coord_size != None:\n        assert isinstance(frac_coord_size, (int, float))\n        grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size\n        grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size\n        grid_t = grid[:, 2] / torch.max(grid[:, 2]) * frac_coord_size\n    else:\n        grid_w, grid_h, grid_t = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio, grid[:, 2]*scale_ratio\n    # use half of dimensions to encode grid_h\n    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, dim)  # (B, N, D/2)\n    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, dim)  # (B, N, D/2)\n    emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, time_dim)  # (B, N, D/2)\n\n    emb = torch.cat([emb_t, emb_h, emb_w], dim=-1) # (B, L, D)\n    return emb\n\ndef get_21d_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):\n    '''\n    grid: (B, 3, N)\n        N = H * W\n        the first dimension represents width, and the second reprensents height\n        e.g.,   [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n                [0. 1. 2. 3. 0. 1. 2. 3. 0. 1. 2. 3.]\n                [0. 0. 0. 0. 1. 1. 1. 1. 2. 2. 2. 
2.]\n    frac_coord_size: \n        if frac_coord_size != None: \n            fractional coordinates for positional embedding is used\n        else: \n            absolute coordinates for positional embedding is used\n    '''\n    assert embed_dim % 2 == 0\n    dim = embed_dim // 2\n    \n    grid = grid.float()\n    if frac_coord_size != None:\n        assert isinstance(frac_coord_size, (int, float))\n        grid_w = grid[:, 0] / torch.max(grid[:, 0]) * frac_coord_size\n        grid_h = grid[:, 1] / torch.max(grid[:, 1]) * frac_coord_size\n        grid_t = grid[:, 2] / torch.max(grid[:, 2]) * frac_coord_size\n    else:\n        grid_w, grid_h, grid_t = grid[:, 0]*scale_ratio, grid[:, 1]*scale_ratio, grid[:, 2]*scale_ratio\n    # use half of dimensions to encode grid_h\n    emb_w = get_1d_sincos_pos_embed_from_grid(grid_w, dim)\n    emb_h = get_1d_sincos_pos_embed_from_grid(grid_h, dim)\n    emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, embed_dim)\n\n    emb = torch.cat([emb_h, emb_w], dim=-1) + emb_t # (B, L, D)\n    return emb\n\ndef get_time_sincos_pos_embed_from_grid(grid, embed_dim, frac_coord_size=None, scale_ratio=1.0):\n    grid = grid.float()\n    grid_t = grid[:, 0]*scale_ratio\n    emb_t = get_1d_sincos_pos_embed_from_grid(grid_t, embed_dim)\n    return emb_t\n\n\n#################################################################################\n#                                 interpolation                                 #\n#################################################################################\n\n\ndef interpolate_sincos_pos_embed(embed_dim, ori_h, ori_w, tgt_h, tgt_w):\n    from src.inf.models.dit import get_2d_sincos_pos_embed\n    pos_embed = get_2d_sincos_pos_embed(embed_dim, ori_h, ori_w)\n    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)\n    pos_embed = rearrange(pos_embed, '1 (h w) d -> 1 d h w', h=ori_h, w=ori_w)\n    pos_embed = F.interpolate(pos_embed, (tgt_h, tgt_w), mode='bilinear')\n    pos_embed = 
rearrange(pos_embed, '1 d h w -> 1 (h w) d')\n    return pos_embed\n\ndef interpolate_sincos_pos_index(embed_dim, ori_h, ori_w, tgt_h, tgt_w):\n    from src.inf.models.dit import get_2d_sincos_pos_embed_from_grid\n    grid_h = np.arange(tgt_h, dtype=np.float32) * ori_h / tgt_h\n    grid_w = np.arange(tgt_w, dtype=np.float32) * ori_w / tgt_w\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n    grid = grid.reshape([2, 1, tgt_h, tgt_w])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)\n    return pos_embed\n"
  },
  {
    "path": "nit/schedulers/flow_matching/loss.py",
    "content": "import torch\nimport numpy as np\nimport torch.nn.functional as F\n\ndef mean_flat(x):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return torch.mean(x, dim=list(range(1, len(x.size()))))\n\ndef sum_flat(x):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return torch.sum(x, dim=list(range(1, len(x.size()))))\n\nclass FlowMatchingLoss:\n    def __init__(\n            self,\n            prediction='v',\n            path_type=\"linear\",\n            weighting=\"uniform\",\n            encoders=[], \n            accelerator=None, \n            latents_scale=None, \n            latents_bias=None,\n            P_mean=0.0,\n            P_std=1.0,\n            sigma_data=1.0,\n            unit_variance=False,\n        ):\n        self.prediction = prediction\n        self.weighting = weighting\n        self.path_type = path_type\n        self.encoders = encoders\n        self.accelerator = accelerator\n        self.latents_scale = latents_scale\n        self.latents_bias = latents_bias\n        self.P_mean = P_mean\n        self.P_std = P_std\n        self.sigma_data = sigma_data\n        self.unit_variance = unit_variance\n\n    def interpolant(self, t):\n        if self.path_type == \"linear\":\n            alpha_t = 1 - t\n            sigma_t = t\n            d_alpha_t = -1\n            d_sigma_t =  1\n        elif self.path_type == \"cosine\":\n            alpha_t = torch.cos(t * torch.pi / 2)\n            sigma_t = torch.sin(t * torch.pi / 2)\n            d_alpha_t = -torch.pi / 2 * torch.sin(t * torch.pi / 2)\n            d_sigma_t =  torch.pi / 2 * torch.cos(t * torch.pi / 2)\n        elif self.path_type == 'triangle':\n            alpha_t = torch.cos(t)\n            sigma_t = torch.sin(t)\n            d_alpha_t = -torch.sin(t)\n            d_sigma_t =  torch.cos(t)\n        else:\n            raise NotImplementedError()\n\n        return alpha_t, sigma_t, d_alpha_t, d_sigma_t\n\n    
def __call__(self, model, batch_size, images, noises, model_kwargs=None, use_dir_loss=False, zs=[]):\n        if model_kwargs == None:\n            model_kwargs = {}\n        # sample timestep according to log-normal distribution of sigmas following EDM\n        rnd_normal = torch.randn((batch_size))\n        sigma = (rnd_normal * self.P_std + self.P_mean).exp()\n        if self.path_type == \"linear\":      # [0, 1]\n            t = sigma / (1 + sigma)        \n        elif self.path_type == \"cosine\":    # [0, 1]\n            t = 2 / np.pi * torch.atan(sigma)\n        elif self.path_type == 'triangle':  # [0, pi/2]\n            t = torch.atan(sigma / self.sigma_data)\n        else:\n            raise NotImplementedError\n        t = t.to(device=images.device, dtype=images.dtype)\n        \n        time_input = t\n\n        hw_list = model_kwargs['hw_list']\n        seqlens = hw_list[:, 0] * hw_list[:, 1]\n        t = torch.cat([t[i].unsqueeze(0).repeat(seqlens[i], 1, 1, 1) for i in range(batch_size)], dim=0)\n        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(t)\n        \n        if self.unit_variance:\n            model_input = alpha_t * images / self.sigma_data + sigma_t * noises \n        else:\n            model_input = alpha_t * images + sigma_t * noises\n   \n        if self.prediction == 'v':\n            model_target = d_alpha_t * images + d_sigma_t * noises\n        else:\n            raise NotImplementedError() # TODO: add x or eps prediction\n        \n        model_kwargs['return_zs'] = True\n        if self.unit_variance:\n            model_output, zs_tilde = self.sigma_data * model(model_input, time_input, **model_kwargs)\n        else:\n            model_output, zs_tilde = model(model_input, time_input, **model_kwargs)\n        \n        denoising_loss = mean_flat((model_output - model_target) ** 2)\n        denoising_loss = torch.nan_to_num(denoising_loss, nan=0, posinf=1e5, neginf=-1e5)\n        loss = denoising_loss.mean()\n\n   
     if use_dir_loss:\n            directional_loss = mean_flat(1 - F.cosine_similarity(model_output, model_target, dim=1))\n            directional_loss = torch.nan_to_num(directional_loss, nan=0, posinf=1e5, neginf=-1e5)\n            loss += directional_loss.mean()\n        \n        proj_loss = 0.\n        if zs != [] and zs != None:\n            for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):\n                proj_loss += 1 - torch.cosine_similarity(z, z_tilde, dim=-1).mean()\n            proj_loss = torch.nan_to_num(proj_loss, nan=0, posinf=1e5, neginf=-1e5)\n\n        return loss, proj_loss\n"
  },
  {
    "path": "nit/schedulers/flow_matching/samplers_c2i.py",
    "content": "import torch\nimport numpy as np\n\n\ndef expand_t_like_x(t, x_cur, hw_list):\n    \"\"\"Function to reshape time t to broadcastable dimension of x\n    Args:\n      t: [batch_dim,], time vector\n      x: [batch_dim,...], data point\n    \"\"\"\n    dims = [1] * (len(x_cur.size()) - 1)\n    seqlens = hw_list[:, 0] * hw_list[:, 1]\n    B = t.shape[0]\n    t = torch.cat([t[i].unsqueeze(0).repeat(int(seqlens[i]), *dims) for i in range(B)], dim=0)\n    return t\n\ndef get_score_from_velocity(vt, xt, t, hw_list, path_type=\"linear\"):\n    \"\"\"Wrapper function: transfrom velocity prediction model to score\n    Args:\n        velocity: [batch_dim, ...] shaped tensor; velocity model output\n        x: [batch_dim, ...] shaped tensor; x_t data point\n        t: [batch_dim,] time tensor\n    \"\"\"\n    t = expand_t_like_x(t, xt, hw_list)\n    if path_type == \"linear\":\n        alpha_t, d_alpha_t = 1 - t, torch.ones_like(t, device=t.device) * -1\n        sigma_t, d_sigma_t = t, torch.ones_like(t, device=t.device)\n    elif path_type == \"cosine\":\n        alpha_t = torch.cos(t * np.pi / 2)\n        sigma_t = torch.sin(t * np.pi / 2)\n        d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)\n        d_sigma_t =  np.pi / 2 * torch.cos(t * np.pi / 2)\n    else:\n        raise NotImplementedError\n\n    mean = xt\n    reverse_alpha_ratio = alpha_t / d_alpha_t\n    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t\n    score = (reverse_alpha_ratio * vt - mean) / var\n\n    return score\n\n\ndef compute_diffusion(t_cur):\n    return 2 * t_cur\n\n\ndef euler_sampler(\n        model,\n        ag_model,\n        latents,\n        y,\n        hw_list,\n        num_steps=20,\n        heun=False,\n        cfg_scale=1.0,\n        guidance_low=0.0,\n        guidance_high=1.0,\n        path_type=\"linear\", # not used, just for compatability\n        ):\n    # setup conditioning\n    if cfg_scale > 1.0:\n        y_null = torch.tensor([1000] * y.size(0), 
device=y.device)\n    if ag_model != None:\n        auto_guidance = True\n    else:\n        auto_guidance = False\n\n    _dtype = latents.dtype    \n    t_steps = torch.linspace(1, 0, num_steps+1, dtype=torch.float64)\n    x_next = latents.to(torch.float64)\n    device = x_next.device\n\n    \n    with torch.no_grad():\n        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):\n            x_cur = x_next\n            if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:\n                model_input = torch.cat([x_cur] * 2, dim=0)\n                y_cur = torch.cat([y, y_null], dim=0)\n                hw_list_cur = torch.cat([hw_list, hw_list], dim=0)\n            else:\n                model_input = x_cur\n                y_cur = y      \n                hw_list_cur = hw_list      \n            kwargs = dict(y=y_cur, hw_list=hw_list_cur)\n            time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur\n            d_cur = model(\n                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n            ).to(torch.float64)\n            if cfg_scale > 1. 
and t_cur <= guidance_high and t_cur >= guidance_low:\n                if auto_guidance:\n                    kwargs = dict(y=y_null, hw_list=hw_list_cur)\n                    time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur\n                    d_cur_uncond = ag_model(\n                        model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n                    ).to(torch.float64)\n                    d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond) \n                else:\n                    d_cur_cond, d_cur_uncond = d_cur.chunk(2)\n                    d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)                \n            x_next = x_cur + (t_next - t_cur) * d_cur\n            if heun and (i < num_steps - 1):\n                if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:\n                    model_input = torch.cat([x_next] * 2)\n                    y_cur = torch.cat([y, y_null], dim=0)\n                    hw_list_cur = torch.cat([hw_list, hw_list], dim=0)\n                else:\n                    model_input = x_next\n                    y_cur = y\n                    hw_list_cur = hw_list\n                kwargs = dict(y=y_cur, hw_list=hw_list_cur)\n                time_input = torch.ones(y_cur.size(0)).to(\n                    device=model_input.device, dtype=torch.float64\n                ) * t_next\n                d_prime = model(\n                    model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n                ).to(torch.float64)\n                if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:\n                    if auto_guidance:\n                        kwargs = dict(y=y_null, hw_list=hw_list_cur)\n                        time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_next\n                        d_prime_uncond = ag_model(\n       
                     model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n                        ).to(torch.float64)\n                        d_prime = d_prime_uncond + cfg_scale * (d_prime - d_prime_uncond)\n                    else:\n                        d_prime_cond, d_prime_uncond = d_prime.chunk(2)\n                        d_prime = d_prime_uncond + cfg_scale * (d_prime_cond - d_prime_uncond)\n                x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)\n                \n    return x_next\n\n\ndef euler_maruyama_sampler(\n        model,\n        ag_model,\n        latents,\n        y,\n        hw_list,\n        num_steps=20,\n        heun=False,  # not used, just for compatability\n        cfg_scale=1.0,\n        guidance_low=0.0,\n        guidance_high=1.0,\n        path_type=\"linear\",\n        ):\n    # setup conditioning\n    if cfg_scale > 1.0:\n        y_null = torch.tensor([1000] * y.size(0), device=y.device)\n    if ag_model != None:\n        auto_guidance = True\n    else:\n        auto_guidance = False\n\n    _dtype = latents.dtype\n\n    t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)\n    t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])\n    x_next = latents.to(torch.float64)\n    device = x_next.device\n\n    with torch.no_grad():\n        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):\n            dt = t_next - t_cur\n            x_cur = x_next\n            if not auto_guidance and cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:\n                model_input = torch.cat([x_cur] * 2, dim=0)\n                y_cur = torch.cat([y, y_null], dim=0)\n                hw_list_cur = torch.cat([hw_list, hw_list], dim=0)\n            else:\n                model_input = x_cur\n                y_cur = y            \n                hw_list_cur = hw_list\n            kwargs = dict(y=y_cur, hw_list=hw_list_cur)\n            
time_input = torch.ones(y_cur.size(0)).to(device=device, dtype=torch.float64) * t_cur\n            diffusion = compute_diffusion(t_cur)            \n            eps_i = torch.randn_like(x_cur).to(device)\n            deps = eps_i * torch.sqrt(torch.abs(dt))\n\n            # compute drift\n            v_cur = model(\n                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n            ).to(torch.float64)\n            s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type)\n            d_cur = v_cur - 0.5 * diffusion * s_cur\n            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:\n                if auto_guidance:\n                    kwargs = dict(y=y_null, hw_list=hw_list_cur)\n                    time_input = torch.ones(y_null.size(0)).to(device=device, dtype=torch.float64) * t_cur\n                    diffusion = compute_diffusion(t_cur)            \n                    eps_i = torch.randn_like(x_cur).to(device)\n                    deps = eps_i * torch.sqrt(torch.abs(dt))\n\n                    # compute drift\n                    v_cur_uncond = ag_model(\n                        model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n                    ).to(torch.float64)\n                    s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type)\n                    d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond\n                    d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond)\n                else:\n                    d_cur_cond, d_cur_uncond = d_cur.chunk(2)\n                    d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)\n\n            x_next =  x_cur + d_cur * dt + torch.sqrt(diffusion) * deps\n    \n    # last step\n    t_cur, t_next = t_steps[-2], t_steps[-1]\n    dt = t_next - t_cur\n    x_cur = x_next\n    if not auto_guidance and 
cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:\n        model_input = torch.cat([x_cur] * 2, dim=0)\n        y_cur = torch.cat([y, y_null], dim=0)\n        hw_list_cur = torch.cat([hw_list, hw_list], dim=0)\n    else:\n        model_input = x_cur\n        y_cur = y            \n        hw_list_cur = hw_list\n    kwargs = dict(y=y_cur, hw_list=hw_list_cur)\n    time_input = torch.ones(y_cur.size(0)).to(\n        device=device, dtype=torch.float64\n    ) * t_cur\n    \n    # compute drift\n    v_cur = model(\n        model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n    ).to(torch.float64)\n    s_cur = get_score_from_velocity(v_cur, model_input, time_input, hw_list_cur, path_type=path_type)\n    diffusion = compute_diffusion(t_cur)\n    d_cur = v_cur - 0.5 * diffusion * s_cur\n    if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:\n        if auto_guidance:\n            kwargs = dict(y=y_null, hw_list=hw_list_cur)\n            time_input = torch.ones(y_null.size(0)).to(\n                device=device, dtype=torch.float64\n            ) * t_cur\n            \n            # compute drift\n            v_cur_uncond = ag_model(\n                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs\n            ).to(torch.float64)\n            s_cur_uncond = get_score_from_velocity(v_cur_uncond, model_input, time_input, hw_list_cur, path_type=path_type)\n            diffusion = compute_diffusion(t_cur)\n            d_cur_uncond = v_cur_uncond - 0.5 * diffusion * s_cur_uncond\n            d_cur = d_cur_uncond + cfg_scale * (d_cur - d_cur_uncond)\n        else:\n            d_cur_cond, d_cur_uncond = d_cur.chunk(2)\n            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)\n\n    mean_x = x_cur + dt * d_cur\n                    \n    return mean_x\n"
  },
  {
    "path": "nit/utils/__init__.py",
    "content": "from .misc_utils import *\nfrom .train_utils import *\nfrom .eval_utils import *\nfrom .gpu_memory_monitor import *"
  },
  {
    "path": "nit/utils/deepspeed_zero_to_fp32.py",
    "content": "#!/usr/bin/env python\n\n# Copyright (c) Microsoft Corporation.\n# SPDX-License-Identifier: Apache-2.0\n\n# DeepSpeed Team\n\n# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets\n# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in\n# the future. Once extracted, the weights don't require DeepSpeed and can be used in any\n# application.\n#\n# example: python zero_to_fp32.py . pytorch_model.bin\n\nimport argparse\nimport torch\nimport glob\nimport math\nimport os\nimport re\nfrom collections import OrderedDict\nfrom dataclasses import dataclass\n\n# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with\n# DeepSpeed data structures it has to be available in the current python environment.\nfrom deepspeed.utils import logger\nfrom deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,\n                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,\n                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)\n\n\n@dataclass\nclass zero_model_state:\n    buffers: dict()\n    param_shapes: dict()\n    shared_params: list\n    ds_version: int\n    frozen_param_shapes: dict()\n    frozen_param_fragments: dict()\n\n\ndebug = 0\n\n# load to cpu\ndevice = torch.device('cpu')\n\n\ndef atoi(text):\n    return int(text) if text.isdigit() else text\n\n\ndef natural_keys(text):\n    '''\n    alist.sort(key=natural_keys) sorts in human order\n    http://nedbatchelder.com/blog/200712/human_sorting.html\n    (See Toothy's implementation in the comments)\n    '''\n    return [atoi(c) for c in re.split(r'(\\d+)', text)]\n\n\ndef get_model_state_file(checkpoint_dir, zero_stage):\n    if not os.path.isdir(checkpoint_dir):\n        raise FileNotFoundError(f\"Directory '{checkpoint_dir}' doesn't 
exist\")\n\n    # there should be only one file\n    if zero_stage <= 2:\n        file = os.path.join(checkpoint_dir, \"mp_rank_00_model_states.pt\")\n    elif zero_stage == 3:\n        file = os.path.join(checkpoint_dir, \"zero_pp_rank_0_mp_rank_00_model_states.pt\")\n\n    if not os.path.exists(file):\n        raise FileNotFoundError(f\"can't find model states file at '{file}'\")\n\n    return file\n\n\ndef get_checkpoint_files(checkpoint_dir, glob_pattern):\n    # XXX: need to test that this simple glob rule works for multi-node setup too\n    ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)\n\n    if len(ckpt_files) == 0:\n        raise FileNotFoundError(f\"can't find {glob_pattern} files in directory '{checkpoint_dir}'\")\n\n    return ckpt_files\n\n\ndef get_optim_files(checkpoint_dir):\n    return get_checkpoint_files(checkpoint_dir, \"*_optim_states.pt\")\n\n\ndef get_model_state_files(checkpoint_dir):\n    return get_checkpoint_files(checkpoint_dir, \"*_model_states.pt\")\n\n\ndef parse_model_states(files):\n    zero_model_states = []\n    for file in files:\n        state_dict = torch.load(file, map_location=device)\n\n        if BUFFER_NAMES not in state_dict:\n            raise ValueError(f\"{file} is not a model state checkpoint\")\n        buffer_names = state_dict[BUFFER_NAMES]\n        if debug:\n            print(\"Found buffers:\", buffer_names)\n\n        # recover just the buffers while restoring them to fp32 if they were saved in fp16\n        buffers = {k: v.float() for k, v in state_dict[\"module\"].items() if k in buffer_names}\n        param_shapes = state_dict[PARAM_SHAPES]\n\n        # collect parameters that are included in param_shapes\n        param_names = []\n        for s in param_shapes:\n            for name in s.keys():\n                param_names.append(name)\n\n        # update with frozen parameters\n        frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)\n        if 
frozen_param_shapes is not None:\n            if debug:\n                print(f\"Found frozen_param_shapes: {frozen_param_shapes}\")\n            param_names += list(frozen_param_shapes.keys())\n\n        # handle shared params\n        shared_params = [[k, v] for k, v in state_dict[\"shared_params\"].items()]\n\n        ds_version = state_dict.get(DS_VERSION, None)\n\n        frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)\n\n        z_model_state = zero_model_state(buffers=buffers,\n                                         param_shapes=param_shapes,\n                                         shared_params=shared_params,\n                                         ds_version=ds_version,\n                                         frozen_param_shapes=frozen_param_shapes,\n                                         frozen_param_fragments=frozen_param_fragments)\n        zero_model_states.append(z_model_state)\n\n    return zero_model_states\n\n\ndef parse_optim_states(files, ds_checkpoint_dir):\n\n    total_files = len(files)\n    state_dicts = []\n    for f in files:\n        state_dict = torch.load(f, map_location=device)\n        # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights\n        # and also handle the case where it was already removed by another helper script\n        state_dict[\"optimizer_state_dict\"].pop(\"optimizer_state_dict\", None)\n        state_dicts.append(state_dict)\n\n    if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:\n        raise ValueError(f\"{files[0]} is not a zero checkpoint\")\n    zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]\n    world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]\n\n    # For ZeRO-2 each param group can have different partition_count as data parallelism for expert\n    # parameters can be different from data parallelism for non-expert parameters. 
So we can just\n    # use the max of the partition_count to get the dp world_size.\n\n    if type(world_size) is list:\n        world_size = max(world_size)\n\n    if world_size != total_files:\n        raise ValueError(\n            f\"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. \"\n            \"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes.\"\n        )\n\n    # the groups are named differently in each stage\n    if zero_stage <= 2:\n        fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS\n    elif zero_stage == 3:\n        fp32_groups_key = FP32_FLAT_GROUPS\n    else:\n        raise ValueError(f\"unknown zero stage {zero_stage}\")\n\n    if zero_stage <= 2:\n        fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]\n    elif zero_stage == 3:\n        # if there is more than one param group, there will be multiple flattened tensors - one\n        # flattened tensor per group - for simplicity merge them into a single tensor\n        #\n        # XXX: could make the script more memory efficient for when there are multiple groups - it\n        # will require matching the sub-lists of param_shapes for each param group flattened tensor\n\n        fp32_flat_groups = [\n            torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))\n        ]\n\n    return zero_stage, world_size, fp32_flat_groups\n\n\ndef _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):\n    \"\"\"\n    Returns fp32 state_dict reconstructed from ds checkpoint\n\n    Args:\n        - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)\n\n    \"\"\"\n    print(f\"Processing zero checkpoint '{ds_checkpoint_dir}'\")\n\n    optim_files = get_optim_files(ds_checkpoint_dir)\n    zero_stage, 
world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)\n    print(f\"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}\")\n\n    model_files = get_model_state_files(ds_checkpoint_dir)\n\n    zero_model_states = parse_model_states(model_files)\n    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')\n\n    if zero_stage <= 2:\n        return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,\n                                                          exclude_frozen_parameters)\n    elif zero_stage == 3:\n        return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,\n                                                          exclude_frozen_parameters)\n\n\ndef _zero2_merge_frozen_params(state_dict, zero_model_states):\n    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:\n        return\n\n    frozen_param_shapes = zero_model_states[0].frozen_param_shapes\n    frozen_param_fragments = zero_model_states[0].frozen_param_fragments\n\n    if debug:\n        num_elem = sum(s.numel() for s in frozen_param_shapes.values())\n        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')\n\n        wanted_params = len(frozen_param_shapes)\n        wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())\n        avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])\n        print(f'Frozen params: Have {avail_numel} numels to process.')\n        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')\n\n    total_params = 0\n    total_numel = 0\n    for name, shape in frozen_param_shapes.items():\n        total_params += 1\n        unpartitioned_numel = shape.numel()\n        total_numel += unpartitioned_numel\n\n        state_dict[name] = frozen_param_fragments[name]\n\n        if debug:\n        
    print(f\"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} \")\n\n    print(f\"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements\")\n\n\ndef _has_callable(obj, fn):\n    attr = getattr(obj, fn, None)\n    return callable(attr)\n\n\ndef _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):\n    param_shapes = zero_model_states[0].param_shapes\n\n    # Reconstruction protocol:\n    #\n    # XXX: document this\n\n    if debug:\n        for i in range(world_size):\n            for j in range(len(fp32_flat_groups[0])):\n                print(f\"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}\")\n\n    # XXX: memory usage doubles here (zero2)\n    num_param_groups = len(fp32_flat_groups[0])\n    merged_single_partition_of_fp32_groups = []\n    for i in range(num_param_groups):\n        merged_partitions = [sd[i] for sd in fp32_flat_groups]\n        full_single_fp32_vector = torch.cat(merged_partitions, 0)\n        merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)\n    avail_numel = sum(\n        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])\n\n    if debug:\n        wanted_params = sum([len(shapes) for shapes in param_shapes])\n        wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])\n        # not asserting if there is a mismatch due to possible padding\n        print(f\"Have {avail_numel} numels to process.\")\n        print(f\"Need {wanted_numel} numels in {wanted_params} params.\")\n\n    # params\n    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support\n    # out-of-core computing solution\n    total_numel = 0\n    total_params = 0\n    for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):\n        offset = 0\n        avail_numel = 
full_single_fp32_vector.numel()\n        for name, shape in shapes.items():\n\n            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)\n            total_numel += unpartitioned_numel\n            total_params += 1\n\n            if debug:\n                print(f\"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} \")\n            state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)\n            offset += unpartitioned_numel\n\n        # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and\n        # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex\n        # paddings performed in the code it's almost impossible to predict the exact numbers w/o the\n        # live optimizer object, so we are checking that the numbers are within the right range\n        align_to = 2 * world_size\n\n        def zero2_align(x):\n            return align_to * math.ceil(x / align_to)\n\n        if debug:\n            print(f\"original offset={offset}, avail_numel={avail_numel}\")\n\n        offset = zero2_align(offset)\n        avail_numel = zero2_align(avail_numel)\n\n        if debug:\n            print(f\"aligned  offset={offset}, avail_numel={avail_numel}\")\n\n        # Sanity check\n        if offset != avail_numel:\n            raise ValueError(f\"consumed {offset} numels out of {avail_numel} - something is wrong\")\n\n    print(f\"Reconstructed fp32 state dict with {total_params} params {total_numel} elements\")\n\n\ndef _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,\n                                               exclude_frozen_parameters):\n    state_dict = OrderedDict()\n\n    # buffers\n    buffers = zero_model_states[0].buffers\n    state_dict.update(buffers)\n    if debug:\n        print(f\"added {len(buffers)} buffers\")\n\n    if not 
exclude_frozen_parameters:\n        _zero2_merge_frozen_params(state_dict, zero_model_states)\n\n    _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)\n\n    # recover shared parameters\n    for pair in zero_model_states[0].shared_params:\n        if pair[1] in state_dict:\n            state_dict[pair[0]] = state_dict[pair[1]]\n\n    return state_dict\n\n\ndef zero3_partitioned_param_info(unpartitioned_numel, world_size):\n    remainder = unpartitioned_numel % world_size\n    padding_numel = (world_size - remainder) if remainder else 0\n    partitioned_numel = math.ceil(unpartitioned_numel / world_size)\n    return partitioned_numel, padding_numel\n\n\ndef _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):\n    if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:\n        return\n\n    if debug:\n        for i in range(world_size):\n            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())\n            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')\n\n        frozen_param_shapes = zero_model_states[0].frozen_param_shapes\n        wanted_params = len(frozen_param_shapes)\n        wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())\n        avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size\n        print(f'Frozen params: Have {avail_numel} numels to process.')\n        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')\n\n    total_params = 0\n    total_numel = 0\n    for name, shape in zero_model_states[0].frozen_param_shapes.items():\n        total_params += 1\n        unpartitioned_numel = shape.numel()\n        total_numel += unpartitioned_numel\n\n        param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)\n        state_dict[name] = 
torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)\n\n        partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)\n\n        if debug:\n            print(\n                f\"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}\"\n            )\n\n    print(f\"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements\")\n\n\ndef _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):\n    param_shapes = zero_model_states[0].param_shapes\n    avail_numel = fp32_flat_groups[0].numel() * world_size\n    # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each\n    # param, re-consolidating each param, while dealing with padding if any\n\n    # merge list of dicts, preserving order\n    param_shapes = {k: v for d in param_shapes for k, v in d.items()}\n\n    if debug:\n        for i in range(world_size):\n            print(f\"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}\")\n\n        wanted_params = len(param_shapes)\n        wanted_numel = sum(shape.numel() for shape in param_shapes.values())\n        # not asserting if there is a mismatch due to possible padding\n        avail_numel = fp32_flat_groups[0].numel() * world_size\n        print(f\"Trainable params: Have {avail_numel} numels to process.\")\n        print(f\"Trainable params: Need {wanted_numel} numels in {wanted_params} params.\")\n\n    # params\n    # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support\n    # out-of-core computing solution\n    offset = 0\n    total_numel = 0\n    total_params = 0\n    for name, shape in param_shapes.items():\n\n        unpartitioned_numel = shape.numel()\n        total_numel += unpartitioned_numel\n        total_params += 1\n\n        
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)\n\n        if debug:\n            print(\n                f\"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}\"\n            )\n\n        # XXX: memory usage doubles here\n        state_dict[name] = torch.cat(\n            tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),\n            0).narrow(0, 0, unpartitioned_numel).view(shape)\n        offset += partitioned_numel\n\n    offset *= world_size\n\n    # Sanity check\n    if offset != avail_numel:\n        raise ValueError(f\"consumed {offset} numels out of {avail_numel} - something is wrong\")\n\n    print(f\"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements\")\n\n\ndef _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,\n                                               exclude_frozen_parameters):\n    state_dict = OrderedDict()\n\n    # buffers\n    buffers = zero_model_states[0].buffers\n    state_dict.update(buffers)\n    if debug:\n        print(f\"added {len(buffers)} buffers\")\n\n    if not exclude_frozen_parameters:\n        _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)\n\n    _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)\n\n    # recover shared parameters\n    for pair in zero_model_states[0].shared_params:\n        if pair[1] in state_dict:\n            state_dict[pair[0]] = state_dict[pair[1]]\n\n    return state_dict\n\n\ndef get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):\n    \"\"\"\n    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with\n    ``load_state_dict()`` and used for training without DeepSpeed or 
shared with others, for example\n    via a model hub.\n\n    Args:\n        - ``checkpoint_dir``: path to the desired checkpoint folder\n        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``\n        - ``exclude_frozen_parameters``: exclude frozen parameters\n\n    Returns:\n        - pytorch ``state_dict``\n\n    Note: this approach may not work if your application doesn't have sufficient free CPU memory and\n    you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with\n    the checkpoint.\n\n    A typical usage might be ::\n\n        from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint\n        # do the training and checkpoint saving\n        state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu\n        model = model.cpu() # move to cpu\n        model.load_state_dict(state_dict)\n        # submit to model hub or save the model to share with others\n\n    In this example the ``model`` will no longer be usable in the deepspeed context of the same\n    application. i.e. 
you will need to re-initialize the deepspeed engine, since\n    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.\n\n    If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.\n\n    \"\"\"\n    if tag is None:\n        latest_path = os.path.join(checkpoint_dir, 'latest')\n        if os.path.isfile(latest_path):\n            with open(latest_path, 'r') as fd:\n                tag = fd.read().strip()\n        else:\n            raise ValueError(f\"Unable to find 'latest' file at {latest_path}\")\n\n    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)\n\n    if not os.path.isdir(ds_checkpoint_dir):\n        raise FileNotFoundError(f\"Directory '{ds_checkpoint_dir}' doesn't exist\")\n\n    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)\n\n\ndef convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False):\n    \"\"\"\n    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be\n    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.\n\n    Args:\n        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)\n        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)\n        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. 
If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``\n        - ``exclude_frozen_parameters``: exclude frozen parameters\n    \"\"\"\n\n    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)\n    print(f\"Saving fp32 state dict to {output_file}\")\n    torch.save(state_dict, output_file)\n\n\ndef load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):\n    \"\"\"\n    1. Put the provided model to cpu\n    2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``\n    3. Load it into the provided model\n\n    Args:\n        - ``model``: the model object to update\n        - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)\n        - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``\n\n    Returns:\n        - ``model`: modified model\n\n    Make sure you have plenty of CPU memory available before you call this function. If you don't\n    have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it\n    conveniently placed for you in the checkpoint folder.\n\n    A typical usage might be ::\n\n        from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint\n        model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)\n        # submit to model hub or save the model to share with others\n\n    Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context\n    of the same application. i.e. 
you will need to re-initialize the deepspeed engine, since\n    ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.\n\n    \"\"\"\n    logger.info(f\"Extracting fp32 weights\")\n    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)\n\n    logger.info(f\"Overwriting model with fp32 weights\")\n    model = model.cpu()\n    model.load_state_dict(state_dict, strict=False)\n\n    return model\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"checkpoint_dir\",\n                        type=str,\n                        help=\"path to the desired checkpoint folder, e.g., path/checkpoint-12\")\n    parser.add_argument(\n        \"output_file\",\n        type=str,\n        help=\"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)\")\n    parser.add_argument(\"-t\",\n                        \"--tag\",\n                        type=str,\n                        default=None,\n                        help=\"checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1\")\n    parser.add_argument(\"--exclude_frozen_parameters\", action='store_true', help=\"exclude frozen parameters\")\n    parser.add_argument(\"-d\", \"--debug\", action='store_true', help=\"enable debug\")\n    args = parser.parse_args()\n\n    debug = args.debug\n\n    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,\n                                               args.output_file,\n                                               tag=args.tag,\n                                               exclude_frozen_parameters=args.exclude_frozen_parameters)\n"
  },
  {
    "path": "nit/utils/ema.py",
    "content": "import torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\n\n\n\n@torch.no_grad()\ndef update_ema(ema_model, model, decay=0.9999):\n    \"\"\"\n    Step the EMA model towards the current model.\n    \"\"\"\n    if hasattr(model, 'module'):\n        model = model.module\n    if hasattr(ema_model, 'module'):\n        ema_model = ema_model.module\n    ema_params = OrderedDict(ema_model.named_parameters())\n    model_params = OrderedDict(model.named_parameters())\n    \n    for name, param in model_params.items():\n        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed\n        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)\n    "
  },
  {
    "path": "nit/utils/eval_utils.py",
    "content": "from PIL import Image\nimport numpy as np\nfrom tqdm import tqdm\nimport torch\nimport re\nimport os\n\nfrom safetensors.torch import load_file\n\n\ndef create_npz_from_sample_folder(sample_dir, num=50_000):\n    \"\"\"\n    Builds a single .npz file from a folder of .png samples.\n    \"\"\"\n    samples = []\n    imgs = sorted(os.listdir(sample_dir), key=lambda x: int(x.split('.')[0]))\n    print(len(imgs))\n    assert len(imgs) >= num\n    for i in tqdm(range(num), desc=\"Building .npz file from samples\"):\n        sample_pil = Image.open(f\"{sample_dir}/{imgs[i]}\")\n        sample_np = np.asarray(sample_pil).astype(np.uint8)\n        samples.append(sample_np)\n    samples = np.stack(samples)\n    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)\n    npz_path = f\"{sample_dir}.npz\"\n    np.savez(npz_path, arr_0=samples)\n    print(f\"Saved .npz file to {npz_path} [shape={samples.shape}].\")\n    return npz_path\n\ndef init_from_ckpt(\n    model, checkpoint_dir, ignore_keys=None, verbose=False\n) -> None: \n    if checkpoint_dir.endswith(\".safetensors\"):\n        model_state_dict=load_file(checkpoint_dir, device='cpu')\n    else:\n        model_state_dict=torch.load(checkpoint_dir,  map_location=\"cpu\")\n    model_new_ckpt=dict()\n    for i in model_state_dict.keys():\n        model_new_ckpt[i] = model_state_dict[i]\n    keys = list(model_new_ckpt.keys())\n    for k in keys:\n        if ignore_keys:\n            for ik in ignore_keys:\n                if ik in k:\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del model_new_ckpt[k]\n    missing, unexpected = model.load_state_dict(model_new_ckpt, strict=False)\n    if verbose:\n        print(\n            f\"Restored with {len(missing)} missing and {len(unexpected)} unexpected keys\"\n        )\n        if len(missing) > 0:\n            print(f\"Missing Keys: {missing}\")\n        if len(unexpected) > 0:\n            
print(f\"Unexpected Keys: {unexpected}\")\n    if verbose:\n        print(\"\")\n\n\ndef none_or_str(value):\n    if value == 'None':\n        return None\n    return value\n\ndef parse_sde_args(parser):\n    group = parser.add_argument_group(\"SDE arguments\")\n    group.add_argument(\"--sde-sampling-method\", type=str, default=\"Euler\", choices=[\"Euler\", \"Heun\"])\n    group.add_argument(\"--diffusion-form\", type=str, default=\"sigma\", \\\n                        choices=[\"constant\", \"SBDM\", \"sigma\", \"linear\", \"decreasing\", \"increasing-decreasing\"],\\\n                        help=\"form of diffusion coefficient in the SDE\")\n    group.add_argument(\"--diffusion-norm\", type=float, default=1.0)\n    group.add_argument(\"--last-step\", type=none_or_str, default=\"Mean\", choices=[None, \"Mean\", \"Tweedie\", \"Euler\"],\\\n                        help=\"form of last step taken in the SDE\")\n    group.add_argument(\"--last-step-size\", type=float, default=0.04, \\\n                        help=\"size of the last step taken\")\n\ndef parse_ode_args(parser):\n    group = parser.add_argument_group(\"ODE arguments\")\n    group.add_argument(\"--ode-sampling-method\", type=str, default=\"dopri5\", help=\"blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq\")\n    group.add_argument(\"--atol\", type=float, default=1e-6, help=\"Absolute tolerance\")\n    group.add_argument(\"--rtol\", type=float, default=1e-3, help=\"Relative tolerance\")\n    group.add_argument(\"--reverse\", action=\"store_true\")\n    group.add_argument(\"--likelihood\", action=\"store_true\")\n\n# ode solvers:\n# - Adaptive-step:\n#   - dopri8 Runge-Kutta 7(8) of Dormand-Prince-Shampine\n#   - dopri5 Runge-Kutta 4(5) of Dormand-Prince [default].\n#   - bosh3 Runge-Kutta 2(3) of Bogacki-Shampine\n#   - adaptive_heun Runge-Kutta 1(2)\n# - Fixed-step:\n#   - euler Euler method.\n#   - midpoint Midpoint method.\n#   - rk4 Fourth-order Runge-Kutta 
with 3/8 rule.\n#   - explicit_adams Explicit Adams.\n#   - implicit_adams Implicit Adams."
  },
  {
    "path": "nit/utils/freeze.py",
    "content": "from diffusers.utils import logging\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\ndef freeze_model(model, trainable_modules={}, verbose=False):\n    logger.info(\"Start freeze\")\n    for name, param in model.named_parameters():\n        param.requires_grad = False\n        if verbose:\n            logger.info(\"freeze moduel: \"+str(name))\n        for trainable_module_name in trainable_modules:\n            if trainable_module_name in name:\n                param.requires_grad = True\n                if verbose:\n                    logger.info(\"unfreeze moduel: \"+str(name))\n                break\n    logger.info(\"End freeze\")\n    params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]\n    params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]\n    logger.info(f\"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M\")\n    logger.info(f\"Freeze Module Parameters: {sum(params_freeze) / 1e6} M\")\n    return "
  },
  {
    "path": "nit/utils/gpu_memory_monitor.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD-style license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport os\nfrom collections import namedtuple\nfrom datetime import datetime\nfrom typing import Any, Dict, Optional\n\nimport torch\n\n# named tuple for passing GPU memory stats for logging\nGPUMemStats = namedtuple(\n    \"GPUMemStats\",\n    [\n        \"max_active_gib\",\n        \"max_active_pct\",\n        \"max_reserved_gib\",\n        \"max_reserved_pct\",\n        \"num_alloc_retries\",\n        \"num_ooms\",\n    ],\n)\n\n\nclass GPUMemoryMonitor:\n    def __init__(self, logger, device: str = \"cuda:0\"):\n        self.device = torch.device(device)  # device object\n        self.device_name = torch.cuda.get_device_name(self.device)\n        self.device_index = torch.cuda.current_device()\n        self.device_capacity = torch.cuda.get_device_properties(\n            self.device\n        ).total_memory\n        self.device_capacity_gib = self._to_gib(self.device_capacity)\n        \n        self.logger = logger\n\n        torch.cuda.reset_peak_memory_stats()\n        torch.cuda.empty_cache()\n\n    def _to_gib(self, memory_in_bytes):\n        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000\n        _gib_in_bytes = 1024 * 1024 * 1024\n        memory_in_gib = memory_in_bytes / _gib_in_bytes\n        return memory_in_gib\n\n    def _to_pct(self, memory):\n        return 100 * memory / self.device_capacity\n\n    def get_peak_stats(self):\n        cuda_info = torch.cuda.memory_stats(self.device)\n\n        max_active = cuda_info[\"active_bytes.all.peak\"]\n        max_active_gib = self._to_gib(max_active)\n        max_active_pct = self._to_pct(max_active)\n\n        max_reserved = cuda_info[\"reserved_bytes.all.peak\"]\n        max_reserved_gib = self._to_gib(max_reserved)\n        max_reserved_pct = self._to_pct(max_reserved)\n\n       
 num_retries = cuda_info[\"num_alloc_retries\"]\n        num_ooms = cuda_info[\"num_ooms\"]\n\n        if num_retries > 0:\n            self.logger.warning(f\"{num_retries} CUDA memory allocation retries.\")\n        if num_ooms > 0:\n            self.logger.warning(f\"{num_ooms} CUDA OOM errors thrown.\")\n\n        return GPUMemStats(\n            max_active_gib,\n            max_active_pct,\n            max_reserved_gib,\n            max_reserved_pct,\n            num_retries,\n            num_ooms,\n        )\n\n    def reset_peak_stats(self):\n        torch.cuda.reset_peak_memory_stats()\n\n\ndef build_gpu_memory_monitor(logger):\n    gpu_memory_monitor = GPUMemoryMonitor(logger, \"cuda\")\n    logger.info(\n        f\"GPU capacity: {gpu_memory_monitor.device_name} ({gpu_memory_monitor.device_index}) \"\n        f\"with {gpu_memory_monitor.device_capacity_gib:.2f}GiB memory\"\n    )\n\n    return gpu_memory_monitor"
  },
  {
    "path": "nit/utils/lr_scheduler.py",
    "content": "from torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"PyTorch optimization for diffusion models.\"\"\"\n\nimport math\nfrom enum import Enum\nfrom typing import Optional, Union\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\nclass SchedulerType(Enum):\n    LINEAR = \"linear\"\n    COSINE = \"cosine\"\n    COSINE_WITH_RESTARTS = \"cosine_with_restarts\"\n    POLYNOMIAL = \"polynomial\"\n    CONSTANT = \"constant\"\n    CONSTANT_WITH_WARMUP = \"constant_with_warmup\"\n    PIECEWISE_CONSTANT = \"piecewise_constant\"\n    WARMDUP_STABLE_DECAY = \"warmup_stable_decay\"\n\n\ndef get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):\n    \"\"\"\n    Create a schedule with a constant learning rate, using the learning rate set in optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)\n\ndef get_constant_schedule_with_warmup(\n    optimizer: Optimizer, num_warmup_steps: int, div_factor: 
int = 1e-4, last_epoch: int = -1\n):\n    def lr_lambda(current_step):\n        # 0,y0 step,y1\n        #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1\n        if current_step < num_warmup_steps:\n            return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor\n        return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\ndef get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):\n    \"\"\"\n    Create a schedule with a constant learning rate, using the learning rate set in optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        step_rules (`string`):\n            The rules for the learning rate. ex: rule_steps=\"1:10,0.1:20,0.01:30,0.005\" it means that the learning rate\n            if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30\n            steps and multiple 0.005 for the other steps.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    rules_dict = {}\n    rule_list = step_rules.split(\",\")\n    for rule_str in rule_list[:-1]:\n        value_str, steps_str = rule_str.split(\":\")\n        steps = int(steps_str)\n        value = float(value_str)\n        rules_dict[steps] = value\n    last_lr_multiple = float(rule_list[-1])\n\n    def create_rules_function(rules_dict, last_lr_multiple):\n        def rule_func(steps: int) -> float:\n            sorted_steps = sorted(rules_dict.keys())\n            for i, sorted_step in enumerate(sorted_steps):\n                if steps < sorted_step:\n                    return rules_dict[sorted_steps[i]]\n            return last_lr_multiple\n\n        return rule_func\n\n    rules_func = 
create_rules_function(rules_dict, last_lr_multiple)\n\n    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)\n\n\ndef get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):\n    \"\"\"\n    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after\n    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    def lr_lambda(current_step: int):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        return max(\n            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))\n        )\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_cosine_schedule_with_warmup(\n    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function between the\n    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n    initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of 
steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        num_periods (`float`, *optional*, defaults to 0.5):\n            The number of periods of the cosine function in a schedule (the default is to just decrease from the max\n            value to 0 following a half-cosine).\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_cosine_with_hard_restarts_schedule_with_warmup(\n    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases following the values of the cosine function between the\n    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases\n    linearly between 0 and the initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        num_cycles (`int`, *optional*, defaults to 1):\n            The number of hard restarts to use.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when 
resuming training.\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n    \"\"\"\n\n    def lr_lambda(current_step):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))\n        if progress >= 1.0:\n            return 0.0\n        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_polynomial_decay_schedule_with_warmup(\n    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1\n):\n    \"\"\"\n    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the\n    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the\n    initial lr set in the optimizer.\n\n    Args:\n        optimizer ([`~torch.optim.Optimizer`]):\n            The optimizer for which to schedule the learning rate.\n        num_warmup_steps (`int`):\n            The number of steps for the warmup phase.\n        num_training_steps (`int`):\n            The total number of training steps.\n        lr_end (`float`, *optional*, defaults to 1e-7):\n            The end LR.\n        power (`float`, *optional*, defaults to 1.0):\n            Power factor.\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n\n    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT\n    implementation at\n    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37\n\n    Return:\n        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n\n    \"\"\"\n\n    lr_init = 
optimizer.defaults[\"lr\"]\n    if not (lr_init > lr_end):\n        raise ValueError(f\"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})\")\n\n    def lr_lambda(current_step: int):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        elif current_step > num_training_steps:\n            return lr_end / lr_init  # as LambdaLR multiplies by lr_init\n        else:\n            lr_range = lr_init - lr_end\n            decay_steps = num_training_steps - num_warmup_steps\n            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n            decay = lr_range * pct_remaining**power + lr_end\n            return decay / lr_init  # as LambdaLR multiplies by lr_init\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\n\ndef get_constant_schedule_with_warmup_and_decay(\n    optimizer: Optimizer, num_warmup_steps: int, num_decay_steps: int, decay_T: int = 50000, div_factor: int = 1e-4, last_epoch: int = -1\n):\n    def lr_lambda(current_step):\n        # 0,y0 step,y1\n        #((y1-y0) * x/step + y0) / y1 = (y1-y0)/y1 * x/step + y0/y1\n        if current_step < num_warmup_steps:\n            return (1 - div_factor) * float(current_step) / float(max(1, num_warmup_steps)) + div_factor\n        if current_step > num_decay_steps:\n            return 0.5 ** ((current_step - num_decay_steps) / decay_T)\n        return 1.0\n\n    return LambdaLR(optimizer, lr_lambda, last_epoch)\n\nTYPE_TO_SCHEDULER_FUNCTION = {\n    SchedulerType.LINEAR: get_linear_schedule_with_warmup,\n    SchedulerType.COSINE: get_cosine_schedule_with_warmup,\n    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,\n    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,\n    SchedulerType.CONSTANT: get_constant_schedule,\n    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,\n    SchedulerType.PIECEWISE_CONSTANT: 
get_piecewise_constant_schedule,\n    SchedulerType.WARMDUP_STABLE_DECAY: get_constant_schedule_with_warmup_and_decay\n}\n\n\n\n\n\n\n\ndef get_scheduler(\n    name: Union[str, SchedulerType],\n    optimizer: Optimizer,\n    step_rules: Optional[str] = None,\n    num_warmup_steps: Optional[int] = None,\n    num_decay_steps: Optional[int] = None,\n    num_training_steps: Optional[int] = None,\n    num_cycles: int = 1,\n    decay_T: Optional[int] = 50000,\n    power: float = 1.0,\n    last_epoch: int = -1,\n):\n    \"\"\"\n    Unified API to get any scheduler from its name.\n\n    Args:\n        name (`str` or `SchedulerType`):\n            The name of the scheduler to use.\n        optimizer (`torch.optim.Optimizer`):\n            The optimizer that will be used during training.\n        step_rules (`str`, *optional*):\n            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.\n        num_warmup_steps (`int`, *optional*):\n            The number of warmup steps to do. This is not required by all schedulers (hence the argument being\n            optional), the function will raise an error if it's unset and the scheduler type requires it.\n        num_decay_steps (`int`, *optional*):\n            The number of decay steps to do. This is not required by all schedulers (hence the argument being\n            optional), the function will raise an error if it's unset and the scheduler type requires it.\n        num_training_steps (`int``, *optional*):\n            The number of training steps to do. This is not required by all schedulers (hence the argument being\n            optional), the function will raise an error if it's unset and the scheduler type requires it.\n        num_cycles (`int`, *optional*):\n            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.\n        power (`float`, *optional*, defaults to 1.0):\n            Power factor. 
 See `POLYNOMIAL` scheduler\n        decay_T (`int`, *optional*, defaults to 50000):\n            Half-life, in steps, of the exponential decay. See `WARMDUP_STABLE_DECAY` scheduler\n        last_epoch (`int`, *optional*, defaults to -1):\n            The index of the last epoch when resuming training.\n    \"\"\"\n    name = SchedulerType(name)\n    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]\n    if name == SchedulerType.CONSTANT:\n        return schedule_func(optimizer, last_epoch=last_epoch)\n\n    if name == SchedulerType.PIECEWISE_CONSTANT:\n        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)\n\n    # All other schedulers require `num_warmup_steps`\n    if num_warmup_steps is None:\n        raise ValueError(f\"{name} requires `num_warmup_steps`, please provide that argument.\")\n\n    if name == SchedulerType.CONSTANT_WITH_WARMUP:\n        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)\n\n    if name == SchedulerType.WARMDUP_STABLE_DECAY:\n        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_decay_steps=num_decay_steps, decay_T=decay_T, last_epoch=last_epoch)\n\n    # All other schedulers require `num_training_steps`\n    if num_training_steps is None:\n        raise ValueError(f\"{name} requires `num_training_steps`, please provide that argument.\")\n\n    if name == SchedulerType.COSINE_WITH_RESTARTS:\n        return schedule_func(\n            optimizer,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            num_cycles=num_cycles,\n            last_epoch=last_epoch,\n        )\n\n    if name == SchedulerType.POLYNOMIAL:\n        return schedule_func(\n            optimizer,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            power=power,\n            last_epoch=last_epoch,\n        )\n\n    return schedule_func(\n        optimizer, num_warmup_steps=num_warmup_steps, 
num_training_steps=num_training_steps, last_epoch=last_epoch\n    )\n"
  },
  {
    "path": "nit/utils/misc_utils.py",
    "content": "import functools\nimport importlib\nimport os\nimport wandb\nimport fsspec\nimport numpy as np\nimport torch\n\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\nfrom safetensors.torch import load_file as load_safetensors\n\n\ndef get_dtype(str_dtype):\n    if str_dtype == 'fp16':\n        return torch.float16\n    elif str_dtype == 'bf16':\n        return torch.bfloat16\n    else:\n        return torch.float32\n    \n    \ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef get_string_from_tuple(s):\n    try:\n        # Check if the string starts and ends with parentheses\n        if s[0] == \"(\" and s[-1] == \")\":\n            # Convert the string to a tuple\n            t = eval(s)\n            # Check if the type of t is tuple\n            if type(t) == tuple:\n                return t[0]\n            else:\n                pass\n    except:\n        pass\n    return s\n\n\ndef is_power_of_two(n):\n    \"\"\"\n    chat.openai.com/chat\n    Return True if n is a power of 2, otherwise return False.\n\n    The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.\n    The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.\n    If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. 
So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.\n    Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.\n\n    \"\"\"\n    if n <= 0:\n        return False\n    return (n & (n - 1)) == 0\n\n\ndef autocast(f, enabled=True):\n    def do_autocast(*args, **kwargs):\n        with torch.cuda.amp.autocast(\n            enabled=enabled,\n            dtype=torch.get_autocast_gpu_dtype(),\n            cache_enabled=torch.is_autocast_cache_enabled(),\n        ):\n            return f(*args, **kwargs)\n\n    return do_autocast\n\n\ndef load_partial_from_config(config):\n    return partial(get_obj_from_str(config[\"target\"]), **config.get(\"params\", dict()))\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype(\"data/DejaVuSans.ttf\", size=size)\n        nc = int(40 * (wh[0] / 256))\n        if isinstance(xc[bi], list):\n            text_seq = xc[bi][0]\n        else:\n            text_seq = xc[bi]\n        lines = \"\\n\".join(\n            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)\n        )\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Cant encode string for logging. 
Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef partialclass(cls, *args, **kwargs):\n    class NewCls(cls):\n        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)\n\n    return NewCls\n\n\ndef make_path_absolute(path):\n    fs, p = fsspec.core.url_to_fs(path)\n    if fs.protocol == \"file\":\n        return os.path.abspath(p)\n    return path\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef isheatmap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n\n    return x.ndim == 2\n\n\ndef isneighbors(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef expand_dims_like(x, y):\n    while x.dim() != y.dim():\n        x = x.unsqueeze(-1)\n    return x\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        if config == \"__is_first_stage__\":\n            return None\n      
  elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False, invalidate_cache=True):\n    module, cls = string.rsplit(\".\", 1)\n    if invalidate_cache:\n        importlib.invalidate_caches()\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef append_zero(x):\n    return torch.cat([x, x.new_zeros([1])])\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(\n            f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\"\n        )\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef load_model_from_config(config, ckpt, verbose=True, freeze=True):\n    print(f\"Loading model from {ckpt}\")\n    if ckpt.endswith(\"ckpt\"):\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n        if \"global_step\" in pl_sd:\n            print(f\"Global Step: {pl_sd['global_step']}\")\n        sd = pl_sd[\"state_dict\"]\n    elif ckpt.endswith(\"safetensors\"):\n        sd = load_safetensors(ckpt)\n    elif ckpt.endswith(\"bin\"):\n        sd = torch.load(ckpt, map_location=\"cpu\")\n    else:\n        raise NotImplementedError\n\n    model = instantiate_from_config(config.model)\n\n    m, u = model.load_state_dict(sd, strict=False)\n\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    if freeze:\n        for param in model.parameters():\n            param.requires_grad = False\n\n    model.eval()\n    return 
model\n\n\ndef format_number(num):\n    num = float(num)\n    num /= 1000.0\n    return '{:.0f}{}'.format(num, 'k')\n\ndef get_num_params(model: torch.nn.ModuleList) -> int:\n    num_params = sum(p.numel() for p in model.parameters())\n    return num_params\n\n\ndef get_num_flop_per_token(num_params, model_config, seq_len) -> int:\n    l, h, q, t = (\n        model_config.n_layers,\n        model_config.n_heads,\n        model_config.dim // model_config.n_heads,\n        seq_len,\n    )\n    # Reasoning behind the factor of 12 for the self-attention part of the formula:\n    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)\n    # 2. the flash attention does 1 more matmul recomputation in the backward\n    #    but recomputation should not be counted in calculating MFU           (+0)\n    # 3. each matmul performs 1 multiplication and 1 addition                 (*2)\n    # 4. we follow the convention and do not account for sparsity in causal attention\n    flop_per_token = 6 * num_params + 12 * l * h * q * t\n\n    return flop_per_token\n\ndef get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:\n    l, h, q = (\n        model_config.n_layers,\n        model_config.n_heads,\n        model_config.dim // model_config.n_heads,\n    )\n    \n    # 1. 每个自注意力层有2个矩阵乘法在前向传播,4个在反向传播 (6)\n    # 2. 每个矩阵乘法执行1次乘法和1次加法 (*2)\n    # 3. 
双向注意力需要考虑所有token对,所以是t^2而不是t\n    flop_per_sequence = 6 * num_params + 12 * l * h * q * seq_len * seq_len\n\n    return flop_per_sequence\n\n\n# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU\ndef get_peak_flops(device_name: str) -> int:\n    if \"A100\" in device_name:\n        # data from https://www.nvidia.com/en-us/data-center/a100/\n        return 312e12\n    elif \"H100\" in device_name:\n        # data from https://www.nvidia.com/en-us/data-center/h100/\n        # NOTE: Specifications are one-half lower without sparsity.\n        if \"NVL\" in device_name:\n            return 1979e12\n        elif \"PCIe\" in device_name:\n            return 756e12\n        else:  # for SXM and other variants\n            return 989e12\n    else:  # for other GPU types, assume A100\n        return 312e12\n\n@dataclass(frozen=True)\nclass Color:\n    black = \"\\033[30m\"\n    red = \"\\033[31m\"\n    green = \"\\033[32m\"\n    yellow = \"\\033[33m\"\n    blue = \"\\033[34m\"\n    magenta = \"\\033[35m\"\n    cyan = \"\\033[36m\"\n    white = \"\\033[37m\"\n    reset = \"\\033[39m\"\n\n\n@dataclass(frozen=True)\nclass NoColor:\n    black = \"\"\n    red = \"\"\n    green = \"\"\n    yellow = \"\"\n    blue = \"\"\n    magenta = \"\"\n    cyan = \"\"\n    white = \"\"\n    reset = \"\""
  },
  {
    "path": "nit/utils/model_utils.py",
    "content": "import os\nimport torch\nfrom transformers import T5EncoderModel, AutoModelForCausalLM, AutoTokenizer\n\n\n\n\n# dc-ae\ndef dc_ae_encode(dc_ae, images):\n    with torch.no_grad():\n        latents = dc_ae.encode(images).latent * dc_ae.config.scaling_factor\n    return latents\n\ndef dc_ae_decode(dc_ae, latents):\n    with torch.no_grad():\n        z = latents / dc_ae.config.scaling_factor\n        if dc_ae.use_slicing and z.size(0) > 1:\n            decoded_slices = [dc_ae._decode(z_slice) for z_slice in z.split(1)]\n            decoded = torch.cat(decoded_slices)\n        else:\n            decoded = dc_ae._decode(z)\n        images = decoded    # decoded images\n    return images\n\n# sd-vae\ndef sd_vae_encode(sd_vae, images):\n    with torch.no_grad():\n        z = sd_vae.encode(images)\n        if isinstance(z, dict):\n            z=z.latent_dist.sample()\n        z = sd_vae.config.scaling_factor * z\n    return z\n\ndef sd_vae_decode(sd_vae, latents):\n    with torch.no_grad():\n        z = 1.0 / sd_vae.config.scaling_factor * latents\n        out = sd_vae.decode(z)\n        if isinstance(out, dict):\n            out=out.sample\n    return out\n\n\n\n\n# load text-encoder\ndef load_text_encoder(text_encoder_dir, device, weight_dtype):\n    \n    os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n    tokenizer = AutoTokenizer.from_pretrained(text_encoder_dir)\n    if 'gemma' in text_encoder_dir:\n        tokenizer.padding_side = \"right\"\n        text_encoder = AutoModelForCausalLM.from_pretrained(\n            text_encoder_dir, attn_implementation=\"flash_attention_2\", device_map='cpu', torch_dtype=weight_dtype\n        ).get_decoder()\n    elif 't5' in text_encoder_dir:\n        text_encoder = T5EncoderModel.from_pretrained(\n            text_encoder_dir, attn_implementation=\"sdpa\", device_map='cpu', torch_dtype=weight_dtype\n        )\n    else: \n        raise NotImplementedError\n    text_encoder.requires_grad_(False)\n    text_encoder 
= text_encoder.eval().to(device=device, dtype=weight_dtype)\n    \n    return text_encoder, tokenizer\n    \ndef encode_prompt(tokenizer, text_encoder, device, weight_dtype, captions, use_last_hidden_state, max_seq_length=256):\n    text_inputs = tokenizer(\n        captions,\n        padding='max_length',\n        max_length=max_seq_length,\n        truncation=True,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids.to(device)\n    prompt_masks = text_inputs.attention_mask.to(device)\n    with torch.no_grad(), torch.autocast(\"cuda\", dtype=weight_dtype):\n        results = text_encoder(\n            input_ids=text_input_ids,\n            attention_mask=prompt_masks,\n            output_hidden_states=True,\n        )\n\n        if use_last_hidden_state:\n            prompt_embeds = results.last_hidden_state\n        else:   # from Imagen paper\n            prompt_embeds = results.hidden_states[-2]\n\n    return prompt_embeds, prompt_masks\n\n\ndef prepare_null_cap_feat_mask(text_encoder_type, device, weight_dtype, use_last_hidden_state, max_seq_length=256):\n    text_encoder, tokenizer = load_text_encoder(\n        text_encoder_dir=text_encoder_type, device=device, weight_dtype=weight_dtype\n    )\n    null_cap_features, null_cap_mask = encode_prompt(\n        tokenizer, text_encoder, device, weight_dtype, \n        \"\", use_last_hidden_state, max_seq_length=max_seq_length\n    )\n    return null_cap_features, null_cap_mask"
  },
  {
    "path": "nit/utils/train_utils.py",
    "content": "import torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom diffusers.utils import logging\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\ndef freeze_model(model, trainable_modules={}, verbose=False):\n    logger.info(\"Start freeze\")\n    for name, param in model.named_parameters():\n        param.requires_grad = False\n        if verbose:\n            logger.info(\"freeze moduel: \"+str(name))\n        for trainable_module_name in trainable_modules:\n            if trainable_module_name in name:\n                param.requires_grad = True\n                if verbose:\n                    logger.info(\"unfreeze moduel: \"+str(name))\n                break\n    logger.info(\"End freeze\")\n    params_unfreeze = [p.numel() if p.requires_grad == True else 0 for n, p in model.named_parameters()]\n    params_freeze = [p.numel() if p.requires_grad == False else 0 for n, p in model.named_parameters()]\n    logger.info(f\"Unfreeze Module Parameters: {sum(params_unfreeze) / 1e6} M\")\n    logger.info(f\"Freeze Module Parameters: {sum(params_freeze) / 1e6} M\")\n    return \n\n\n@torch.no_grad()\ndef update_ema(ema_model, model, decay=0.9999):\n    \"\"\"\n    Step the EMA model towards the current model.\n    \"\"\"\n    if hasattr(model, 'module'):\n        model = model.module\n    if hasattr(ema_model, 'module'):\n        ema_model = ema_model.module\n    ema_params = OrderedDict(ema_model.named_parameters())\n    model_params = OrderedDict(model.named_parameters())\n    \n    for name, param in model_params.items():\n        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed\n        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)\n\n\n\ndef log_validation(model):\n    pass"
  },
  {
    "path": "nit/utils/util.py",
    "content": "import functools\nimport importlib\nimport os\nimport wandb\nimport fsspec\nimport numpy as np\nimport torch\n\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\nfrom safetensors.torch import load_file as load_safetensors\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef get_string_from_tuple(s):\n    try:\n        # Check if the string starts and ends with parentheses\n        if s[0] == \"(\" and s[-1] == \")\":\n            # Convert the string to a tuple\n            t = eval(s)\n            # Check if the type of t is tuple\n            if type(t) == tuple:\n                return t[0]\n            else:\n                pass\n    except:\n        pass\n    return s\n\n\ndef is_power_of_two(n):\n    \"\"\"\n    chat.openai.com/chat\n    Return True if n is a power of 2, otherwise return False.\n\n    The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.\n    The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.\n    If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.\n    Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. 
Otherwise, the function returns False.\n\n    \"\"\"\n    if n <= 0:\n        return False\n    return (n & (n - 1)) == 0\n\n\ndef autocast(f, enabled=True):\n    def do_autocast(*args, **kwargs):\n        with torch.cuda.amp.autocast(\n            enabled=enabled,\n            dtype=torch.get_autocast_gpu_dtype(),\n            cache_enabled=torch.is_autocast_cache_enabled(),\n        ):\n            return f(*args, **kwargs)\n\n    return do_autocast\n\n\ndef load_partial_from_config(config):\n    return partial(get_obj_from_str(config[\"target\"]), **config.get(\"params\", dict()))\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype(\"data/DejaVuSans.ttf\", size=size)\n        nc = int(40 * (wh[0] / 256))\n        if isinstance(xc[bi], list):\n            text_seq = xc[bi][0]\n        else:\n            text_seq = xc[bi]\n        lines = \"\\n\".join(\n            text_seq[start : start + nc] for start in range(0, len(text_seq), nc)\n        )\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Cant encode string for logging. 
Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef partialclass(cls, *args, **kwargs):\n    class NewCls(cls):\n        __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)\n\n    return NewCls\n\n\ndef make_path_absolute(path):\n    fs, p = fsspec.core.url_to_fs(path)\n    if fs.protocol == \"file\":\n        return os.path.abspath(p)\n    return path\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef isheatmap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n\n    return x.ndim == 2\n\n\ndef isneighbors(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef expand_dims_like(x, y):\n    while x.dim() != y.dim():\n        x = x.unsqueeze(-1)\n    return x\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        if config == \"__is_first_stage__\":\n            return None\n      
  elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False, invalidate_cache=True):\n    module, cls = string.rsplit(\".\", 1)\n    if invalidate_cache:\n        importlib.invalidate_caches()\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef append_zero(x):\n    return torch.cat([x, x.new_zeros([1])])\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    if dims_to_append < 0:\n        raise ValueError(\n            f\"input has {x.ndim} dims but target_dims is {target_dims}, which is less\"\n        )\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef load_model_from_config(config, ckpt, verbose=True, freeze=True):\n    print(f\"Loading model from {ckpt}\")\n    if ckpt.endswith(\"ckpt\"):\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n        if \"global_step\" in pl_sd:\n            print(f\"Global Step: {pl_sd['global_step']}\")\n        sd = pl_sd[\"state_dict\"]\n    elif ckpt.endswith(\"safetensors\"):\n        sd = load_safetensors(ckpt)\n    elif ckpt.endswith(\"bin\"):\n        sd = torch.load(ckpt, map_location=\"cpu\")\n    else:\n        raise NotImplementedError\n\n    model = instantiate_from_config(config.model)\n\n    m, u = model.load_state_dict(sd, strict=False)\n\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    if freeze:\n        for param in model.parameters():\n            param.requires_grad = False\n\n    model.eval()\n    return 
model\n\n\ndef format_number(num):\n    num = float(num)\n    num /= 1000.0\n    return '{:.0f}{}'.format(num, 'k')\n\ndef get_num_params(model: torch.nn.ModuleList) -> int:\n    num_params = sum(p.numel() for p in model.parameters())\n    return num_params\n\n\ndef get_num_flop_per_token(num_params, model_config, seq_len) -> int:\n    l, h, q, t = (\n        model_config.n_layers,\n        model_config.n_heads,\n        model_config.dim // model_config.n_heads,\n        seq_len,\n    )\n    # Reasoning behind the factor of 12 for the self-attention part of the formula:\n    # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)\n    # 2. the flash attention does 1 more matmul recomputation in the backward\n    #    but recomputation should not be counted in calculating MFU           (+0)\n    # 3. each matmul performs 1 multiplication and 1 addition                 (*2)\n    # 4. we follow the convention and do not account for sparsity in causal attention\n    flop_per_token = 6 * num_params + 12 * l * h * q * t\n\n    return flop_per_token\n\ndef get_num_flop_per_sequence_encoder_only(num_params, model_config, seq_len) -> int:\n    l, h, q = (\n        model_config.n_layers,\n        model_config.n_heads,\n        model_config.dim // model_config.n_heads,\n    )\n    \n    # 1. 每个自注意力层有2个矩阵乘法在前向传播,4个在反向传播 (6)\n    # 2. 每个矩阵乘法执行1次乘法和1次加法 (*2)\n    # 3. 
双向注意力需要考虑所有token对,所以是t^2而不是t\n    flop_per_sequence = 6 * num_params + 12 * l * h * q * seq_len * seq_len\n\n    return flop_per_sequence\n\n\n# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU\ndef get_peak_flops(device_name: str) -> int:\n    if \"A100\" in device_name:\n        # data from https://www.nvidia.com/en-us/data-center/a100/\n        return 312e12\n    elif \"H100\" in device_name:\n        # data from https://www.nvidia.com/en-us/data-center/h100/\n        # NOTE: Specifications are one-half lower without sparsity.\n        if \"NVL\" in device_name:\n            return 1979e12\n        elif \"PCIe\" in device_name:\n            return 756e12\n        else:  # for SXM and other variants\n            return 989e12\n    else:  # for other GPU types, assume A100\n        return 312e12\n\n@dataclass(frozen=True)\nclass Color:\n    black = \"\\033[30m\"\n    red = \"\\033[31m\"\n    green = \"\\033[32m\"\n    yellow = \"\\033[33m\"\n    blue = \"\\033[34m\"\n    magenta = \"\\033[35m\"\n    cyan = \"\\033[36m\"\n    white = \"\\033[37m\"\n    reset = \"\\033[39m\"\n\n\n@dataclass(frozen=True)\nclass NoColor:\n    black = \"\"\n    red = \"\"\n    green = \"\"\n    yellow = \"\"\n    blue = \"\"\n    magenta = \"\"\n    cyan = \"\"\n    white = \"\"\n    reset = \"\""
  },
  {
    "path": "nit/utils/video_utils.py",
    "content": "import os\nimport cv2\nimport numpy as np\nfrom PIL import Image\n\ndef save_video_as_mp4(video_array, fps, output_path):\n    # video_array: TCHW (RGB)\n    height, width = video_array.shape[2], video_array.shape[3]\n    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n    for t in range(video_array.shape[0]):\n        frame = video_array[t].transpose(1, 2, 0)\n        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # RGB->BGR\n        out.write(cv2.convertScaleAbs(frame))\n    out.release()\n\ndef save_video_as_png(video_array, output_path):\n    os.makedirs(output_path, exist_ok=True)\n    # video_array: TCHW (RGB)\n    for i, sample in enumerate(video_array):\n        sample = np.transpose(sample, (1, 2, 0))\n        Image.fromarray(sample).save( # HWC\n            os.path.join(output_path, f\"{i:06d}.png\")\n        )\n"
  },
  {
    "path": "nit/utils/warp_pos_idx.py",
    "content": "import torch\nimport random\nfrom typing import Optional, Union\n\n\ndef warp_pos_idx_from_grid(\n    grid: torch.Tensor, \n    shift: Optional[int] = 0, \n    scale: Optional[str] = None, \n    max_len: Optional[Union[int, float]]=None\n):\n    '''\n    grid: the 2-D positional index to be warped (B, 2, D)\n    shift: the max shift value for the positional indices\n    scale: the scale scheme for warping positional indices\n    max_len: the max scale length\n    '''\n    grid[:, 0] = warp_pos_idx(grid[:, 0], shift, scale, max_len)\n    grid[:, 1] = warp_pos_idx(grid[:, 1], shift, scale, max_len)\n    return grid\n    \n\n\ndef warp_pos_idx(\n    pos_idx: torch.Tensor, \n    shift: Optional[int] = 0, \n    scale: Optional[str] = None, \n    max_len: Optional[Union[int, float]]=None\n):\n    '''\n    pos_idx: the 1-D positional index to be warped (B, D)\n    shift: the max shift value for the positional indices\n    scale: the scale scheme for warping positional indices\n    max_len: the max scale length\n    '''\n    if scale != None:\n        assert isinstance(scale, str) and isinstance(max_len, (int, float))\n        if scale.lower() == 'linear':\n            pos_idx = max_len * (pos_idx / pos_idx.max())\n        elif scale.lower() == 'sqrt':\n            pos_idx = max_len * torch.sqrt(pos_idx / max_len)\n        elif scale.lower() in ['sine', 'cosine', 'sin', 'cos']:\n            pos_idx = max_len * torch.sin(pos_idx / max_len * (torch.pi/2))\n        else:\n            raise NotImplementedError('Only support linear, cosine, beta scale scheme for warping')\n        \n    pos_idx = pos_idx + random.randint(0, shift)\n    \n    return pos_idx\n\n"
  },
  {
    "path": "projects/evaluate/adm_evaluator.py",
    "content": "import argparse\nimport io\nimport os\nimport random\nimport warnings\nimport zipfile\nfrom abc import ABC, abstractmethod\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom multiprocessing import cpu_count\nfrom multiprocessing.pool import ThreadPool\nfrom typing import Iterable, Optional, Tuple\nfrom PIL import Image\n\nimport numpy as np\nimport requests\nimport tensorflow.compat.v1 as tf\nfrom scipy import linalg\nfrom tqdm.auto import tqdm\n\nINCEPTION_V3_URL = \"https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb\"\nINCEPTION_V3_PATH = \"checkpoints/classify_image_graph_def.pb\"\n\nFID_POOL_NAME = \"pool_3:0\"\nFID_SPATIAL_NAME = \"mixed_6/conv:0\"\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"ref_batch\", help=\"path to reference batch npz file\")\n    parser.add_argument(\"sample_batch\", help=\"path to sample batch npz file\")\n    args = parser.parse_args()\n\n    config = tf.ConfigProto(\n        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph\n    )\n    config.gpu_options.allow_growth = True\n    evaluator = Evaluator(tf.Session(config=config))\n\n    print(\"warming up TensorFlow...\")\n    # This will cause TF to print a bunch of verbose stuff now rather\n    # than after the next print(), to help prevent confusion.\n    evaluator.warmup()\n\n    print(\"computing reference batch activations...\")\n    ref_acts = evaluator.read_activations(args.ref_batch)\n    print(\"computing/reading reference batch statistics...\")\n    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)\n\n    print(\"computing sample batch activations...\")\n    sample_acts = evaluator.read_activations(args.sample_batch)\n    print(\"computing/reading sample batch statistics...\")\n    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)\n\n    
print(\"Computing evaluations...\")\n    print(\"Inception Score:\", evaluator.compute_inception_score(sample_acts[0]))\n    print(\"FID:\", sample_stats.frechet_distance(ref_stats))\n    print(\"sFID:\", sample_stats_spatial.frechet_distance(ref_stats_spatial))\n    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])\n    print(\"Precision:\", prec)\n    print(\"Recall:\", recall)\n\n\nclass InvalidFIDException(Exception):\n    pass\n\n\nclass FIDStatistics:\n    def __init__(self, mu: np.ndarray, sigma: np.ndarray):\n        self.mu = mu\n        self.sigma = sigma\n\n    def frechet_distance(self, other, eps=1e-6):\n        \"\"\"\n        Compute the Frechet distance between two sets of statistics.\n        \"\"\"\n        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132\n        mu1, sigma1 = self.mu, self.sigma\n        mu2, sigma2 = other.mu, other.sigma\n\n        mu1 = np.atleast_1d(mu1)\n        mu2 = np.atleast_1d(mu2)\n\n        sigma1 = np.atleast_2d(sigma1)\n        sigma2 = np.atleast_2d(sigma2)\n\n        assert (\n            mu1.shape == mu2.shape\n        ), f\"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}\"\n        assert (\n            sigma1.shape == sigma2.shape\n        ), f\"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}\"\n\n        diff = mu1 - mu2\n\n        # product might be almost singular\n        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)\n        if not np.isfinite(covmean).all():\n            msg = (\n                \"fid calculation produces singular product; adding %s to diagonal of cov estimates\"\n                % eps\n            )\n            warnings.warn(msg)\n            offset = np.eye(sigma1.shape[0]) * eps\n            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n\n        # numerical error might give slight imaginary component\n        if 
np.iscomplexobj(covmean):\n            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n                m = np.max(np.abs(covmean.imag))\n                raise ValueError(\"Imaginary component {}\".format(m))\n            covmean = covmean.real\n\n        tr_covmean = np.trace(covmean)\n\n        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean\n\n\nclass Evaluator:\n    def __init__(\n        self,\n        session,\n        batch_size=64,\n        softmax_batch_size=512,\n    ):\n        self.sess = session\n        self.batch_size = batch_size\n        self.softmax_batch_size = softmax_batch_size\n        self.manifold_estimator = ManifoldEstimator(session)\n        with self.sess.graph.as_default():\n            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])\n            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])\n            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)\n            self.softmax = _create_softmax_graph(self.softmax_input)\n\n    def warmup(self):\n        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))\n\n    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:\n        if npz_path.endswith('.npz'):\n            with open_npz_array(npz_path, \"arr_0\") as reader:\n                return self.compute_activations(reader.read_batches(self.batch_size))\n        else:\n            preds = []\n            spatial_preds = []\n            files = os.listdir(npz_path)\n            run_iter = int(len(files) / self.batch_size)\n            for i in tqdm(range(run_iter)):\n                samples = []\n                for file in files[i*self.batch_size: (i+1)*self.batch_size]:\n                    try:\n                        sample_pil = Image.open(os.path.join(npz_path, file))\n                        sample_np = np.asarray(sample_pil).astype(np.uint8)\n                        
samples.append(sample_np)\n                    except:\n                        print('wrong file', os.path.join(npz_path, file))\n                samples = np.stack(samples)\n                samples = samples.astype(np.float32)\n                pred, spatial_pred = self.sess.run(\n                    [self.pool_features, self.spatial_features], {self.image_input: samples}\n                )\n                preds.append(pred.reshape([pred.shape[0], -1]))\n                spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))\n            return (\n                np.concatenate(preds, axis=0),\n                np.concatenate(spatial_preds, axis=0),\n            )\n\n    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Compute image features for downstream evals.\n\n        :param batches: a iterator over NHWC numpy arrays in [0, 255].\n        :return: a tuple of numpy arrays of shape [N x X], where X is a feature\n                 dimension. 
The tuple is (pool_3, spatial).\n        \"\"\"\n        preds = []\n        spatial_preds = []\n        for batch in tqdm(batches):\n            batch = batch.astype(np.float32)\n            pred, spatial_pred = self.sess.run(\n                [self.pool_features, self.spatial_features], {self.image_input: batch}\n            )\n            preds.append(pred.reshape([pred.shape[0], -1]))\n            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))\n        return (\n            np.concatenate(preds, axis=0),\n            np.concatenate(spatial_preds, axis=0),\n        )\n\n    def read_statistics(\n        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]\n    ) -> Tuple[FIDStatistics, FIDStatistics]:\n        if npz_path.endswith('.npz'):\n            obj = np.load(npz_path)\n            if \"mu\" in list(obj.keys()):\n                return FIDStatistics(obj[\"mu\"], obj[\"sigma\"]), FIDStatistics(\n                    obj[\"mu_s\"], obj[\"sigma_s\"]\n                )\n        return tuple(self.compute_statistics(x) for x in activations)\n\n    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:\n        mu = np.mean(activations, axis=0)\n        sigma = np.cov(activations, rowvar=False)\n        return FIDStatistics(mu, sigma)\n\n    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:\n        softmax_out = []\n        for i in range(0, len(activations), self.softmax_batch_size):\n            acts = activations[i : i + self.softmax_batch_size]\n            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))\n        preds = np.concatenate(softmax_out, axis=0)\n        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46\n        scores = []\n        for i in range(0, len(preds), split_size):\n            part = preds[i : i + split_size]\n            kl = part * 
(np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))\n            kl = np.mean(np.sum(kl, 1))\n            scores.append(np.exp(kl))\n        return float(np.mean(scores))\n\n    def compute_prec_recall(\n        self, activations_ref: np.ndarray, activations_sample: np.ndarray\n    ) -> Tuple[float, float]:\n        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)\n        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)\n        pr = self.manifold_estimator.evaluate_pr(\n            activations_ref, radii_1, activations_sample, radii_2\n        )\n        return (float(pr[0][0]), float(pr[1][0]))\n\n\nclass ManifoldEstimator:\n    \"\"\"\n    A helper for comparing manifolds of feature vectors.\n\n    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57\n    \"\"\"\n\n    def __init__(\n        self,\n        session,\n        row_batch_size=10000,\n        col_batch_size=10000,\n        nhood_sizes=(3,),\n        clamp_to_percentile=None,\n        eps=1e-5,\n    ):\n        \"\"\"\n        Estimate the manifold of given feature vectors.\n\n        :param session: the TensorFlow session.\n        :param row_batch_size: row batch size to compute pairwise distances\n                               (parameter to trade-off between memory usage and performance).\n        :param col_batch_size: column batch size to compute pairwise distances.\n        :param nhood_sizes: number of neighbors used to estimate the manifold.\n        :param clamp_to_percentile: prune hyperspheres that have radius larger than\n                                    the given percentile.\n        :param eps: small number for numerical stability.\n        \"\"\"\n        self.distance_block = DistanceBlock(session)\n        self.row_batch_size = row_batch_size\n        self.col_batch_size = col_batch_size\n        self.nhood_sizes = nhood_sizes\n        
self.num_nhoods = len(nhood_sizes)\n        self.clamp_to_percentile = clamp_to_percentile\n        self.eps = eps\n\n    def warmup(self):\n        feats, radii = (\n            np.zeros([1, 2048], dtype=np.float32),\n            np.zeros([1, 1], dtype=np.float32),\n        )\n        self.evaluate_pr(feats, radii, feats, radii)\n\n    def manifold_radii(self, features: np.ndarray) -> np.ndarray:\n        num_images = len(features)\n\n        # Estimate manifold of features by calculating distances to k-NN of each sample.\n        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)\n        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)\n        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)\n\n        for begin1 in range(0, num_images, self.row_batch_size):\n            end1 = min(begin1 + self.row_batch_size, num_images)\n            row_batch = features[begin1:end1]\n\n            for begin2 in range(0, num_images, self.col_batch_size):\n                end2 = min(begin2 + self.col_batch_size, num_images)\n                col_batch = features[begin2:end2]\n\n                # Compute distances between batches.\n                distance_batch[\n                    0 : end1 - begin1, begin2:end2\n                ] = self.distance_block.pairwise_distances(row_batch, col_batch)\n\n            # Find the k-nearest neighbor from the current batch.\n            radii[begin1:end1, :] = np.concatenate(\n                [\n                    x[:, self.nhood_sizes]\n                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)\n                ],\n                axis=0,\n            )\n\n        if self.clamp_to_percentile is not None:\n            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)\n            radii[radii > max_distances] = 0\n        return radii\n\n    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: 
np.ndarray):\n        \"\"\"\n        Evaluate if new feature vectors are at the manifold.\n        \"\"\"\n        num_eval_images = eval_features.shape[0]\n        num_ref_images = radii.shape[0]\n        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)\n        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)\n        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)\n        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)\n\n        for begin1 in range(0, num_eval_images, self.row_batch_size):\n            end1 = min(begin1 + self.row_batch_size, num_eval_images)\n            feature_batch = eval_features[begin1:end1]\n\n            for begin2 in range(0, num_ref_images, self.col_batch_size):\n                end2 = min(begin2 + self.col_batch_size, num_ref_images)\n                ref_batch = features[begin2:end2]\n\n                distance_batch[\n                    0 : end1 - begin1, begin2:end2\n                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)\n\n            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.\n            # If a feature vector is inside a hypersphere of some reference sample, then\n            # the new sample lies at the estimated manifold.\n            # The radii of the hyperspheres are determined from distances of neighborhood size k.\n            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii\n            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)\n\n            max_realism_score[begin1:end1] = np.max(\n                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1\n            )\n            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)\n\n        return {\n            \"fraction\": float(np.mean(batch_predictions)),\n           
 \"batch_predictions\": batch_predictions,\n            \"max_realisim_score\": max_realism_score,\n            \"nearest_indices\": nearest_indices,\n        }\n\n    def evaluate_pr(\n        self,\n        features_1: np.ndarray,\n        radii_1: np.ndarray,\n        features_2: np.ndarray,\n        radii_2: np.ndarray,\n    ) -> Tuple[np.ndarray, np.ndarray]:\n        \"\"\"\n        Evaluate precision and recall efficiently.\n\n        :param features_1: [N1 x D] feature vectors for reference batch.\n        :param radii_1: [N1 x K1] radii for reference vectors.\n        :param features_2: [N2 x D] feature vectors for the other batch.\n        :param radii_2: [N x K2] radii for other vectors.\n        :return: a tuple of arrays for (precision, recall):\n                 - precision: an np.ndarray of length K1\n                 - recall: an np.ndarray of length K2\n        \"\"\"\n        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)\n        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)\n        for begin_1 in range(0, len(features_1), self.row_batch_size):\n            end_1 = begin_1 + self.row_batch_size\n            batch_1 = features_1[begin_1:end_1]\n            for begin_2 in range(0, len(features_2), self.col_batch_size):\n                end_2 = begin_2 + self.col_batch_size\n                batch_2 = features_2[begin_2:end_2]\n                batch_1_in, batch_2_in = self.distance_block.less_thans(\n                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]\n                )\n                features_1_status[begin_1:end_1] |= batch_1_in\n                features_2_status[begin_2:end_2] |= batch_2_in\n        return (\n            np.mean(features_2_status.astype(np.float64), axis=0),\n            np.mean(features_1_status.astype(np.float64), axis=0),\n        )\n\n\nclass DistanceBlock:\n    \"\"\"\n    Calculate pairwise distances between vectors.\n\n    
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34\n    \"\"\"\n\n    def __init__(self, session):\n        self.session = session\n\n        # Initialize TF graph to calculate pairwise distances.\n        with session.graph.as_default():\n            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])\n            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])\n            distance_block_16 = _batch_pairwise_distances(\n                tf.cast(self._features_batch1, tf.float16),\n                tf.cast(self._features_batch2, tf.float16),\n            )\n            self.distance_block = tf.cond(\n                tf.reduce_all(tf.math.is_finite(distance_block_16)),\n                lambda: tf.cast(distance_block_16, tf.float32),\n                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),\n            )\n\n            # Extra logic for less thans.\n            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])\n            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])\n            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]\n            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)\n            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)\n\n    def pairwise_distances(self, U, V):\n        \"\"\"\n        Evaluate pairwise distances between two batches of feature vectors.\n        \"\"\"\n        return self.session.run(\n            self.distance_block,\n            feed_dict={self._features_batch1: U, self._features_batch2: V},\n        )\n\n    def less_thans(self, batch_1, radii_1, batch_2, radii_2):\n        return self.session.run(\n            [self._batch_1_in, self._batch_2_in],\n            feed_dict={\n                self._features_batch1: batch_1,\n                
self._features_batch2: batch_2,\n                self._radii1: radii_1,\n                self._radii2: radii_2,\n            },\n        )\n\n\ndef _batch_pairwise_distances(U, V):\n    \"\"\"\n    Compute pairwise distances between two batches of feature vectors.\n    \"\"\"\n    with tf.variable_scope(\"pairwise_dist_block\"):\n        # Squared norms of each row in U and V.\n        norm_u = tf.reduce_sum(tf.square(U), 1)\n        norm_v = tf.reduce_sum(tf.square(V), 1)\n\n        # norm_u as a column and norm_v as a row vectors.\n        norm_u = tf.reshape(norm_u, [-1, 1])\n        norm_v = tf.reshape(norm_v, [1, -1])\n\n        # Pairwise squared Euclidean distances.\n        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)\n\n    return D\n\n\nclass NpzArrayReader(ABC):\n    @abstractmethod\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        pass\n\n    @abstractmethod\n    def remaining(self) -> int:\n        pass\n\n    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:\n        def gen_fn():\n            while True:\n                batch = self.read_batch(batch_size)\n                if batch is None:\n                    break\n                yield batch\n\n        rem = self.remaining()\n        num_batches = rem // batch_size + int(rem % batch_size != 0)\n        return BatchIterator(gen_fn, num_batches)\n\n\nclass BatchIterator:\n    def __init__(self, gen_fn, length):\n        self.gen_fn = gen_fn\n        self.length = length\n\n    def __len__(self):\n        return self.length\n\n    def __iter__(self):\n        return self.gen_fn()\n\n\nclass StreamingNpzArrayReader(NpzArrayReader):\n    def __init__(self, arr_f, shape, dtype):\n        self.arr_f = arr_f\n        self.shape = shape\n        self.dtype = dtype\n        self.idx = 0\n\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        if self.idx >= self.shape[0]:\n            return None\n\n        bs = 
min(batch_size, self.shape[0] - self.idx)\n        self.idx += bs\n\n        if self.dtype.itemsize == 0:\n            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)\n\n        read_count = bs * np.prod(self.shape[1:])\n        read_size = int(read_count * self.dtype.itemsize)\n        data = _read_bytes(self.arr_f, read_size, \"array data\")\n        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])\n\n    def remaining(self) -> int:\n        return max(0, self.shape[0] - self.idx)\n\n\nclass MemoryNpzArrayReader(NpzArrayReader):\n    def __init__(self, arr):\n        self.arr = arr\n        self.idx = 0\n\n    @classmethod\n    def load(cls, path: str, arr_name: str):\n        with open(path, \"rb\") as f:\n            arr = np.load(f)[arr_name]\n        return cls(arr)\n\n    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:\n        if self.idx >= self.arr.shape[0]:\n            return None\n\n        res = self.arr[self.idx : self.idx + batch_size]\n        self.idx += batch_size\n        return res\n\n    def remaining(self) -> int:\n        return max(0, self.arr.shape[0] - self.idx)\n\n\n@contextmanager\ndef open_npz_array(path: str, arr_name: str) -> NpzArrayReader:\n    with _open_npy_file(path, arr_name) as arr_f:\n        version = np.lib.format.read_magic(arr_f)\n        if version == (1, 0):\n            header = np.lib.format.read_array_header_1_0(arr_f)\n        elif version == (2, 0):\n            header = np.lib.format.read_array_header_2_0(arr_f)\n        else:\n            yield MemoryNpzArrayReader.load(path, arr_name)\n            return\n        shape, fortran, dtype = header\n        if fortran or dtype.hasobject:\n            yield MemoryNpzArrayReader.load(path, arr_name)\n        else:\n            yield StreamingNpzArrayReader(arr_f, shape, dtype)\n\n\ndef _read_bytes(fp, size, error_template=\"ran out of data\"):\n    \"\"\"\n    Copied from: 
https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886\n\n    Read from file-like object until size bytes are read.\n    Raises ValueError if not EOF is encountered before size bytes are read.\n    Non-blocking objects only supported if they derive from io objects.\n    Required as e.g. ZipExtFile in python 2.6 can return less data than\n    requested.\n    \"\"\"\n    data = bytes()\n    while True:\n        # io files (default in python3) return None or raise on\n        # would-block, python2 file will truncate, probably nothing can be\n        # done about that.  note that regular files can't be non-blocking\n        try:\n            r = fp.read(size - len(data))\n            data += r\n            if len(r) == 0 or len(data) == size:\n                break\n        except io.BlockingIOError:\n            pass\n    if len(data) != size:\n        msg = \"EOF: reading %s, expected %d bytes got %d\"\n        raise ValueError(msg % (error_template, size, len(data)))\n    else:\n        return data\n\n\n@contextmanager\ndef _open_npy_file(path: str, arr_name: str):\n    with open(path, \"rb\") as f:\n        with zipfile.ZipFile(f, \"r\") as zip_f:\n            if f\"{arr_name}.npy\" not in zip_f.namelist():\n                raise ValueError(f\"missing {arr_name} in npz file\")\n            with zip_f.open(f\"{arr_name}.npy\", \"r\") as arr_f:\n                yield arr_f\n\n\ndef _download_inception_model():\n    if os.path.exists(INCEPTION_V3_PATH):\n        return\n    print(\"downloading InceptionV3 model...\")\n    with requests.get(INCEPTION_V3_URL, stream=True) as r:\n        r.raise_for_status()\n        tmp_path = INCEPTION_V3_PATH + \".tmp\"\n        with open(tmp_path, \"wb\") as f:\n            for chunk in tqdm(r.iter_content(chunk_size=8192)):\n                f.write(chunk)\n        os.rename(tmp_path, INCEPTION_V3_PATH)\n\n\ndef _create_feature_graph(input_batch):\n    _download_inception_model()\n 
   prefix = f\"{random.randrange(2**32)}_{random.randrange(2**32)}\"\n    with open(INCEPTION_V3_PATH, \"rb\") as f:\n        graph_def = tf.GraphDef()\n        graph_def.ParseFromString(f.read())\n    pool3, spatial = tf.import_graph_def(\n        graph_def,\n        input_map={f\"ExpandDims:0\": input_batch},\n        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],\n        name=prefix,\n    )\n    _update_shapes(pool3)\n    spatial = spatial[..., :7]\n    return pool3, spatial\n\n\ndef _create_softmax_graph(input_batch):\n    _download_inception_model()\n    prefix = f\"{random.randrange(2**32)}_{random.randrange(2**32)}\"\n    with open(INCEPTION_V3_PATH, \"rb\") as f:\n        graph_def = tf.GraphDef()\n        graph_def.ParseFromString(f.read())\n    (matmul,) = tf.import_graph_def(\n        graph_def, return_elements=[f\"softmax/logits/MatMul\"], name=prefix\n    )\n    w = matmul.inputs[1]\n    logits = tf.matmul(input_batch, w)\n    return tf.nn.softmax(logits)\n\n\ndef _update_shapes(pool3):\n    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63\n    ops = pool3.graph.get_operations()\n    for op in ops:\n        for o in op.outputs:\n            shape = o.get_shape()\n            if shape._dims is not None:  # pylint: disable=protected-access\n                # shape = [s.value for s in shape] TF 1.x\n                shape = [s for s in shape]  # TF 2.x\n                new_shape = []\n                for j, s in enumerate(shape):\n                    if s == 1 and j == 0:\n                        new_shape.append(None)\n                    else:\n                        new_shape.append(s)\n                o.__dict__[\"_shape_val\"] = tf.TensorShape(new_shape)\n    return pool3\n\n\ndef _numpy_partition(arr, kth, **kwargs):\n    num_workers = min(cpu_count(), len(arr))\n    chunk_size = len(arr) // num_workers\n    extra = len(arr) % num_workers\n\n    start_idx = 0\n    batches = []\n    for i in 
range(num_workers):\n        size = chunk_size + (1 if i < extra else 0)\n        batches.append(arr[start_idx : start_idx + size])\n        start_idx += size\n\n    with ThreadPool(num_workers) as pool:\n        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "projects/preprocess/image_latent_c2i.py",
    "content": "import os\nimport torch\nimport argparse\nimport datetime\nimport time\nimport torchvision\nimport logging\nimport math\nimport shutil\nimport accelerate\nimport torch\nimport torch.utils.checkpoint\nimport diffusers\nimport numpy as np\nimport torch.nn.functional as F\nimport einops\nimport json\nimport os.path as osp\nimport functools\n\nfrom PIL import Image\nfrom torch.cuda import amp\nfrom torch.utils.data import DataLoader, Dataset\nfrom omegaconf import OmegaConf\nfrom accelerate import Accelerator, skip_first_batches\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom tqdm.auto import tqdm\nfrom diffusers import AutoencoderKL, AutoencoderDC\nfrom nit.utils.misc_utils import instantiate_from_config\nfrom torchvision import transforms\nfrom torchvision.datasets.folder import DatasetFolder, default_loader\nfrom torchvision.transforms.functional import hflip \nfrom safetensors.torch import save_file\nfrom typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union\nfrom nit.utils.model_utils import dc_ae_encode, sd_vae_encode\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n# For Omegaconf Tuple\ndef resolve_tuple(*args):\n    return tuple(args)\nOmegaConf.register_new_resolver(\"tuple\", resolve_tuple)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"\",\n        help=\"The config file for training.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"t2i_linear_attention\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--seed\", \n        type=int, \n        default=None, \n        
help=\"A seed for reproducible training.\"\n    )\n    args = parser.parse_args()\n    return args\n\n\n\ndef center_crop_arr(pil_image, image_size):\n    \"\"\"\n    Center cropping implementation from ADM.\n    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126\n    \"\"\"\n    while min(*pil_image.size) >= 2 * image_size:\n        pil_image = pil_image.resize(\n            tuple(x // 2 for x in pil_image.size), resample=Image.Resampling.BOX\n        )\n\n    scale = image_size / min(*pil_image.size)\n    pil_image = pil_image.resize(\n        tuple(round(x * scale) for x in pil_image.size), resample=Image.Resampling.BICUBIC\n    )\n\n    arr = np.array(pil_image)\n    crop_y = (arr.shape[0] - image_size) // 2\n    crop_x = (arr.shape[1] - image_size) // 2\n    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])\n\nIMG_EXTENSIONS = (\".jpg\", \".jpeg\", \".png\", \".ppm\", \".bmp\", \".pgm\", \".tif\", \".tiff\", \".webp\")\n\nclass ImageFolder(DatasetFolder):\n    \"\"\"A generic data loader where the images are arranged in this way by default: ::\n\n        root/dog/xxx.png\n        root/dog/xxy.png\n        root/dog/[...]/xxz.png\n\n        root/cat/123.png\n        root/cat/nsdf3.png\n        root/cat/[...]/asd932_.png\n\n    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so\n    the same methods can be overridden to customize the dataset.\n\n    Args:\n        root (string): Root directory path.\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. 
E.g, ``transforms.RandomCrop``\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        loader (callable, optional): A function to load an image given its path.\n        is_valid_file (callable, optional): A function that takes path of an Image file\n            and check if the file is a valid file (used to check of corrupt files)\n\n     Attributes:\n        classes (list): List of the class names sorted alphabetically.\n        class_to_idx (dict): Dict with items (class_name, class_index).\n        imgs (list): List of (image path, class_index) tuples\n    \"\"\"\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        target_transform: Optional[Callable] = None,\n        loader: Callable[[str], Any] = default_loader,\n        is_valid_file: Optional[Callable[[str], bool]] = None,\n    ):\n        super().__init__(\n            root,\n            loader,\n            IMG_EXTENSIONS if is_valid_file is None else None,\n            transform=transform,\n            target_transform=target_transform,\n            is_valid_file=is_valid_file,\n        )\n        self.imgs = self.samples\n    \n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            tuple: (sample, target) where target is class_index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        sample = self.loader(path)\n        if self.transform is not None:\n            sample = self.transform(sample)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return sample, target, path\n\nclass ImagenetDataDictWrapper(Dataset):\n    def __init__(self, dataset):\n        super().__init__()\n        self.dataset = dataset\n\n    def __getitem__(self, i):\n        x, y, p = self.dataset[i]\n    
    return {\"jpg\": x, \"cls\": y, \"path\": p}\n\n    def __len__(self):\n        return len(self.dataset)\n\n# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60\ndef get_train_sampler(global_batch_size, max_steps, resume_step):\n    sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long)\n    return sample_indices[resume_step * global_batch_size : ].tolist()\n\n\nclass ImagenetLoader():\n    def __init__(self, data_config):\n        super().__init__()\n\n        self.batch_size = data_config.dataloader.batch_size\n        self.num_workers = data_config.dataloader.num_workers\n\n        transform = transforms.Compose([\n            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, data_config.dataset.resolution)),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)\n        ])\n        \n        self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, transform=transform))\n        \n        self.test_dataset = None\n        self.val_dataset = None\n\n    def train_len(self):\n        return len(self.train_dataset)\n\n    def train_dataloader(self, global_batch_size, max_steps, resume_step):\n        sampler = get_train_sampler(\n            global_batch_size, max_steps, resume_step\n        )\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            sampler=sampler,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=True\n        )\n\n    def test_dataloader(self):\n        return None\n\n    def val_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=False\n        )\ndef 
main(args):\n    project_dir = args.project_dir\n    config = OmegaConf.load(args.config)\n    model_config = config.model \n    data_config = config.data\n    train_config = config.training\n\n    config_dir = osp.join(project_dir, 'configs')\n    checkpoint_dir = osp.join(project_dir, 'checkpoints')\n    logging_dir = osp.join(project_dir, 'logs')\n    sample_dir = osp.join(project_dir, 'samples')\n\n    accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=train_config.gradient_accumulation_steps,\n        mixed_precision=train_config.mixed_precision,\n        log_with=train_config.tracker,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes\n    )\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        os.makedirs(project_dir, exist_ok=True)\n        os.makedirs(config_dir, exist_ok=True)\n        os.makedirs(checkpoint_dir, exist_ok=True)\n        os.makedirs(logging_dir, exist_ok=True)\n        os.makedirs(sample_dir, exist_ok=True)\n        OmegaConf.save(config=config, f=osp.join(config_dir, \"config.yaml\"))\n    \n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if 
args.seed is not None:\n        set_seed(args.seed)\n\n    \n    if train_config.allow_tf32: # for A100\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    # Setup models\n    weight_dtype = torch.float32\n    if 'sd-vae' in model_config.vae:    \n        sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)\n        sd_vae.eval()\n        sd_vae.requires_grad_(False)\n        encode_func = functools.partial(sd_vae_encode, sd_vae)\n    elif 'dc-ae' in model_config.vae:\n        dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)\n        dc_ae.eval()\n        dc_ae.requires_grad_(False)\n        encode_func = functools.partial(dc_ae_encode, dc_ae)\n        \n    \n\n    # Setup Dataloader\n    total_batch_size = (\n        data_config.dataloader.batch_size * \n        accelerator.num_processes * \n        train_config.gradient_accumulation_steps\n    )\n    global_steps = 0\n    if train_config.resume_from_checkpoint:\n        # normal read with safety check\n        if train_config.resume_from_checkpoint != \"latest\":\n            resume_from_path = os.path.basename(train_config.resume_from_checkpoint)\n        else:   # Get the most recent checkpoint\n            dirs = os.listdir(checkpoint_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None\n\n        if resume_from_path is None:\n            logger.info(\n                f\"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            train_config.resume_from_checkpoint = None\n        else:\n            global_steps = int(resume_from_path.split(\"-\")[1]) # gs not calculate the gradient_accumulation_steps\n            logger.info(f\"Resuming from steps: {global_steps}\")\n    \n    get_train_dataloader = ImagenetLoader(data_config)\n    train_len = get_train_dataloader.train_len()\n    train_config.max_train_steps = math.ceil(train_len / total_batch_size)\n    train_dataloader = get_train_dataloader.train_dataloader(\n        global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps, \n    )\n\n    \n    # Prepare Accelerate\n    train_dataloader= accelerator.prepare(train_dataloader)    \n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}\")\n    logger.info(f\"  Dataset Length = {get_train_dataloader.train_len()}\")\n    logger.info(f\"  Instantaneous batch size per device = {data_config.dataloader.batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {train_config.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {train_config.max_train_steps}\")\n    \n\n    # Potentially load in the weights and states from a previous save\n    if train_config.resume_from_checkpoint and resume_from_path != None:\n        accelerator.print(f\"Resuming from checkpoint {resume_from_path}\")\n        accelerator.load_state(resume_from_path)\n\n        \n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(\n        range(0, train_config.max_train_steps), \n        disable = not accelerator.is_local_main_process\n    )\n    progress_bar.set_description(\"Optim Steps\")\n    progress_bar.update(global_steps)\n    \n    # prepare patch size and max sequence length\n    # make directory\n    os.makedirs(data_config.dataset.target_dir, exist_ok=True)\n    def save_data(z, y, p):\n        # p: 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG'\n        # target_folder: 'target_dir/n01440764'\n        target_folder = os.path.join(data_config.dataset.target_dir, p.split('/')[-2])\n        f_name = p.split('/')[-1].split('.')[0]\n        os.makedirs(target_folder, exist_ok=True)\n        single_data = dict(latent=z.contiguous(), label=y)\n        save_file(single_data, os.path.join(target_folder, f'{f_name}.safetensors'))\n    \n    for step, batch in enumerate(train_dataloader, start=global_steps):\n        for batch_key in batch.keys():\n            if not isinstance(batch[batch_key], (list, str)):\n                batch[batch_key] = batch[batch_key].to(dtype=weight_dtype)\n            x = batch['jpg']\n            y = batch['cls']\n            p = batch['path']\n            if 'sd-vae' in model_config.vae:\n                z_ori = encode_func(x)\n                z_hflip = encode_func(hflip(x))\n                z = torch.stack([z_ori, z_hflip], dim=1)\n  
          elif 'dc-ae' in model_config.vae:\n                z_ori = encode_func(x)\n                z_hflip = encode_func(hflip(x))\n                z = torch.stack([z_ori, z_hflip], dim=1)\n            \n            for i in range(len(p)):\n                save_data(z[i], y[i], p[i])\n            \n        # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation\n        if accelerator.sync_gradients: \n            progress_bar.update(1)\n            global_steps += 1\n                \n            if global_steps % train_config.checkpointing_steps == 0:\n                if accelerator.is_main_process:\n                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                    if train_config.checkpoints_total_limit is not None:\n                        checkpoints = os.listdir(checkpoint_dir)\n                        checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                        if len(checkpoints) >= train_config.checkpoints_total_limit:\n                            num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1\n                            removing_checkpoints = checkpoints[0:num_to_remove]\n\n                            logger.info(\n                                f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                            )\n                            logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                            for removing_checkpoint in removing_checkpoints:\n                                removing_checkpoint = os.path.join(checkpoint_dir, 
removing_checkpoint)\n                                shutil.rmtree(removing_checkpoint)\n\n                    save_path = os.path.join(checkpoint_dir, f\"checkpoint-{global_steps}\")\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n                    \n                accelerator.wait_for_everyone()\n        \n        if global_steps >= train_config.max_train_steps:\n            break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n\n        \nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)"
  },
  {
    "path": "projects/preprocess/image_nr_latent_c2i.py",
    "content": "import os\nimport torch\nimport argparse\nimport datetime\nimport time\nimport torchvision\nimport logging\nimport math\nimport shutil\nimport accelerate\nimport torch\nimport torch.utils.checkpoint\nimport diffusers\nimport numpy as np\nimport torch.nn.functional as F\nimport einops\nimport json\nimport os.path as osp\nimport functools\n\nfrom PIL import Image\nfrom torch.cuda import amp\nfrom torch.utils.data import DataLoader, Dataset\nfrom omegaconf import OmegaConf\nfrom accelerate import Accelerator, skip_first_batches\nfrom accelerate.logging import get_logger\nfrom accelerate.state import AcceleratorState\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom tqdm.auto import tqdm\nfrom diffusers import AutoencoderKL, AutoencoderDC\nfrom nit.utils.misc_utils import instantiate_from_config\nfrom torchvision import transforms\nfrom torchvision.datasets.folder import DatasetFolder, default_loader\nfrom torchvision.transforms.functional import hflip \nfrom safetensors.torch import save_file\nfrom typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union\nfrom nit.utils.model_utils import dc_ae_encode, sd_vae_encode\n\nlogger = get_logger(__name__, log_level=\"INFO\")\n\n# For Omegaconf Tuple\ndef resolve_tuple(*args):\n    return tuple(args)\nOmegaConf.register_new_resolver(\"tuple\", resolve_tuple)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"\",\n        help=\"The config file for training.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"t2i_linear_attention\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--seed\", \n        type=int, \n        default=None, \n        
help=\"A seed for reproducible training.\"\n    )\n    args = parser.parse_args()\n    return args\n\n\n\ndef native_resolution_resize(pil_image, min_image_size, max_image_size):\n    \"\"\"\n    Center cropping implementation from ADM.\n    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126\n    \"\"\"\n    w, h = pil_image.size\n    if w * h < max_image_size**2:\n        new_w = max(1, int(w/min_image_size)) * min_image_size\n        new_h = max(1, int(h/min_image_size)) * min_image_size\n    else:\n        new_w = np.sqrt(w/h) * max_image_size\n        new_h = new_w * h / w        \n        new_w = int(new_w/min_image_size) * min_image_size\n        new_h = int(new_h/min_image_size) * min_image_size\n    pil_image = pil_image.resize((new_w, new_h), resample=Image.Resampling.BICUBIC)\n    return pil_image\n\nIMG_EXTENSIONS = (\".jpg\", \".jpeg\", \".png\", \".ppm\", \".bmp\", \".pgm\", \".tif\", \".tiff\", \".webp\")\n\nclass ImageFolder(DatasetFolder):\n    \"\"\"A generic data loader where the images are arranged in this way by default: ::\n\n        root/dog/xxx.png\n        root/dog/xxy.png\n        root/dog/[...]/xxz.png\n\n        root/cat/123.png\n        root/cat/nsdf3.png\n        root/cat/[...]/asd932_.png\n\n    This class inherits from :class:`~torchvision.datasets.DatasetFolder` so\n    the same methods can be overridden to customize the dataset.\n\n    Args:\n        root (string): Root directory path.\n        transform (callable, optional): A function/transform that  takes in an PIL image\n            and returns a transformed version. 
E.g, ``transforms.RandomCrop``\n        target_transform (callable, optional): A function/transform that takes in the\n            target and transforms it.\n        loader (callable, optional): A function to load an image given its path.\n        is_valid_file (callable, optional): A function that takes path of an Image file\n            and check if the file is a valid file (used to check of corrupt files)\n\n     Attributes:\n        classes (list): List of the class names sorted alphabetically.\n        class_to_idx (dict): Dict with items (class_name, class_index).\n        imgs (list): List of (image path, class_index) tuples\n    \"\"\"\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        target_transform: Optional[Callable] = None,\n        loader: Callable[[str], Any] = default_loader,\n        is_valid_file: Optional[Callable[[str], bool]] = None,\n    ):\n        super().__init__(\n            root,\n            loader,\n            IMG_EXTENSIONS if is_valid_file is None else None,\n            transform=transform,\n            target_transform=target_transform,\n            is_valid_file=is_valid_file,\n        )\n        self.imgs = self.samples\n    \n    def __getitem__(self, index: int) -> Tuple[Any, Any]:\n        \"\"\"\n        Args:\n            index (int): Index\n\n        Returns:\n            tuple: (sample, target) where target is class_index of the target class.\n        \"\"\"\n        path, target = self.samples[index]\n        sample = self.loader(path)\n        if self.transform is not None:\n            sample = self.transform(sample)\n        if self.target_transform is not None:\n            target = self.target_transform(target)\n\n        return sample, target, path\n\nclass ImagenetDataDictWrapper(Dataset):\n    def __init__(self, dataset):\n        super().__init__()\n        self.dataset = dataset\n\n    def __getitem__(self, i):\n        x, y, p = self.dataset[i]\n    
    return {\"jpg\": x, \"cls\": y, \"path\": p}\n\n    def __len__(self):\n        return len(self.dataset)\n\n# from https://github.com/Alpha-VLLM/LLaMA2-Accessory/blob/main/Large-DiT-ImageNet/train.py#L60\ndef get_train_sampler(global_batch_size, max_steps, resume_step):\n    sample_indices = torch.arange(0, max_steps * global_batch_size,).to(torch.long)\n    return sample_indices[resume_step * global_batch_size : ].tolist()\n\n\nclass ImagenetLoader():\n    def __init__(self, data_config):\n        super().__init__()\n\n        self.batch_size = data_config.dataloader.batch_size\n        self.num_workers = data_config.dataloader.num_workers\n\n        transform = transforms.Compose([\n            transforms.Lambda(lambda pil_image: native_resolution_resize(\n                pil_image, data_config.dataset.min_image_size, data_config.dataset.max_image_size\n            )),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)\n        ])\n        \n        self.train_dataset = ImagenetDataDictWrapper(ImageFolder(data_config.dataset.data_dir, transform=transform))\n        \n        self.test_dataset = None\n        self.val_dataset = None\n\n    def train_len(self):\n        return len(self.train_dataset)\n\n    def train_dataloader(self, global_batch_size, max_steps, resume_step):\n        sampler = get_train_sampler(\n            global_batch_size, max_steps, resume_step\n        )\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            sampler=sampler,\n            num_workers=self.num_workers,\n            pin_memory=True,\n            drop_last=False\n        )\n\n    def test_dataloader(self):\n        return None\n\n    def val_dataloader(self):\n        return DataLoader(\n            self.train_dataset,\n            batch_size=self.batch_size,\n            shuffle=self.shuffle,\n            num_workers=self.num_workers,\n    
        pin_memory=True,\n            drop_last=True\n        )\ndef main(args):\n    project_dir = args.project_dir\n    config = OmegaConf.load(args.config)\n    model_config = config.model \n    data_config = config.data\n    train_config = config.training\n\n    config_dir = osp.join(project_dir, 'configs')\n    checkpoint_dir = osp.join(project_dir, 'checkpoints')\n    logging_dir = osp.join(project_dir, 'logs')\n    sample_dir = osp.join(project_dir, 'samples')\n\n    accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=train_config.gradient_accumulation_steps,\n        mixed_precision=train_config.mixed_precision,\n        log_with=train_config.tracker,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes\n    )\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        os.makedirs(project_dir, exist_ok=True)\n        os.makedirs(config_dir, exist_ok=True)\n        os.makedirs(checkpoint_dir, exist_ok=True)\n        os.makedirs(logging_dir, exist_ok=True)\n        os.makedirs(sample_dir, exist_ok=True)\n        OmegaConf.save(config=config, f=osp.join(config_dir, \"config.yaml\"))\n    \n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        
diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    \n    if train_config.allow_tf32: # for A100\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    # Setup models\n    weight_dtype = torch.float32\n    if 'sd-vae' in model_config.vae:    \n        sd_vae = AutoencoderKL.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)\n        sd_vae.eval()\n        sd_vae.requires_grad_(False)\n        encode_func = functools.partial(sd_vae_encode, sd_vae)\n    elif 'dc-ae' in model_config.vae:\n        dc_ae = AutoencoderDC.from_pretrained(model_config.vae).to(accelerator.device, dtype=weight_dtype)\n        dc_ae.eval()\n        dc_ae.requires_grad_(False)\n        encode_func = functools.partial(dc_ae_encode, dc_ae)\n        \n    \n    \n\n    # Setup Dataloader\n    total_batch_size = (\n        data_config.dataloader.batch_size * \n        accelerator.num_processes * \n        train_config.gradient_accumulation_steps\n    )\n    global_steps = 0\n    if train_config.resume_from_checkpoint:\n        # normal read with safety check\n        if train_config.resume_from_checkpoint != \"latest\":\n            resume_from_path = os.path.basename(train_config.resume_from_checkpoint)\n        else:   # Get the most recent checkpoint\n            dirs = os.listdir(checkpoint_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None\n\n        if resume_from_path is None:\n            logger.info(\n                f\"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            train_config.resume_from_checkpoint = None\n        else:\n            global_steps = int(resume_from_path.split(\"-\")[1]) # gs not calculate the gradient_accumulation_steps\n            logger.info(f\"Resuming from steps: {global_steps}\")\n    \n    get_train_dataloader = ImagenetLoader(data_config)\n    train_len = get_train_dataloader.train_len()\n    train_config.max_train_steps = math.ceil(train_len / total_batch_size)\n    train_dataloader = get_train_dataloader.train_dataloader(\n        global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, resume_step=global_steps, \n    )\n\n    \n    # Prepare Accelerate\n    train_dataloader= accelerator.prepare(train_dataloader)    \n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}\")\n    logger.info(f\"  Dataset Length = {get_train_dataloader.train_len()}\")\n    logger.info(f\"  Instantaneous batch size per device = {data_config.dataloader.batch_size}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {train_config.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {train_config.max_train_steps}\")\n    \n\n    # Potentially load in the weights and states from a previous save\n    if train_config.resume_from_checkpoint and resume_from_path != None:\n        accelerator.print(f\"Resuming from checkpoint {resume_from_path}\")\n        accelerator.load_state(resume_from_path)\n\n        \n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(\n        range(0, train_config.max_train_steps), \n        disable = not accelerator.is_local_main_process\n    )\n    progress_bar.set_description(\"Optim Steps\")\n    progress_bar.update(global_steps)\n    \n    # prepare patch size and max sequence length\n    # make directory\n    os.makedirs(data_config.dataset.target_dir, exist_ok=True)\n    def save_data(z, y, p):\n        # p: 'datasets/imagenet1k/images/train/n01440764/n01440764_10026.JPEG'\n        # target_folder: 'target_dir/n01440764'\n        target_folder = os.path.join(data_config.dataset.target_dir, p.split('/')[-2])\n        f_name = p.split('/')[-1].split('.')[0]\n        os.makedirs(target_folder, exist_ok=True)\n        single_data = dict(latent=z.contiguous(), label=y)\n        save_file(single_data, os.path.join(target_folder, f'{f_name}.safetensors'))\n    \n    for step, batch in enumerate(train_dataloader, start=global_steps):\n        for batch_key in batch.keys():\n            if not isinstance(batch[batch_key], (list, str)):\n                batch[batch_key] = batch[batch_key].to(dtype=weight_dtype)\n            x = batch['jpg']\n            y = batch['cls']\n            p = batch['path']\n            if 'sd-vae' in model_config.vae:\n                z_ori = encode_func(x)\n                z_hflip = encode_func(hflip(x))\n                z = torch.stack([z_ori, z_hflip], dim=1)\n  
          elif 'dc-ae' in model_config.vae:\n                z_ori = encode_func(x)\n                z_hflip = encode_func(hflip(x))\n                z = torch.stack([z_ori, z_hflip], dim=1)\n            \n            for i in range(len(p)):\n                save_data(z[i], y[i], p[i])\n            \n        # Checks if the accelerator has performed an optimization step behind the scenes; Check gradient accumulation\n        if accelerator.sync_gradients: \n            progress_bar.update(1)\n            global_steps += 1\n                \n            if global_steps % train_config.checkpointing_steps == 0:\n                if accelerator.is_main_process:\n                    # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                    if train_config.checkpoints_total_limit is not None:\n                        checkpoints = os.listdir(checkpoint_dir)\n                        checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                        checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                        # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                        if len(checkpoints) >= train_config.checkpoints_total_limit:\n                            num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1\n                            removing_checkpoints = checkpoints[0:num_to_remove]\n\n                            logger.info(\n                                f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                            )\n                            logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                            for removing_checkpoint in removing_checkpoints:\n                                removing_checkpoint = os.path.join(checkpoint_dir, 
removing_checkpoint)\n                                shutil.rmtree(removing_checkpoint)\n\n                    save_path = os.path.join(checkpoint_dir, f\"checkpoint-{global_steps}\")\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n                    \n                accelerator.wait_for_everyone()\n        \n        if global_steps >= train_config.max_train_steps:\n            break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n\n        \nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)"
  },
  {
    "path": "projects/sample/sample_c2i_ddp.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\"\nSamples a large number of images from a pre-trained SiT model using DDP.\nSubsequently saves a .npz file that can be used to compute FID and other\nevaluation metrics via the ADM repo: https://github.com/openai/guided-diffusion/tree/main/evaluations\n\nFor a simple single-GPU/CPU sampling script, see sample.py.\n\"\"\"\nimport torch\nimport torch.distributed as dist\nfrom diffusers.models import AutoencoderKL, AutoencoderDC\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nimport numpy as np\nimport math\nimport functools\nimport argparse\nfrom omegaconf import OmegaConf\nfrom einops import rearrange\nfrom nit.schedulers.flow_matching.samplers_c2i import euler_sampler, euler_maruyama_sampler\nfrom nit.utils import init_from_ckpt\nfrom nit.utils.misc_utils import instantiate_from_config\nfrom nit.utils.model_utils import sd_vae_decode, dc_ae_decode\n\n\ndef create_npz_from_sample_folder(sample_dir, num=50_000):\n    \"\"\"\n    Builds a single .npz file from a folder of .png samples.\n    \"\"\"\n    samples = []\n    for i in tqdm(range(num), desc=\"Building .npz file from samples\"):\n        sample_pil = Image.open(f\"{sample_dir}/{i:06d}.png\")\n        sample_np = np.asarray(sample_pil).astype(np.uint8)\n        samples.append(sample_np)\n    samples = np.stack(samples)\n    assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)\n    npz_path = f\"{sample_dir}.npz\"\n    np.savez(npz_path, arr_0=samples)\n    print(f\"Saved .npz file to {npz_path} [shape={samples.shape}].\")\n    return npz_path\n\n\ndef main(args):\n    \"\"\"\n    Run sampling.\n    \"\"\"\n    torch.backends.cuda.matmul.allow_tf32 = args.tf32  # True: fast but may lead to some small numerical differences\n    assert torch.cuda.is_available(), 
\"Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage\"\n    torch.set_grad_enabled(False)\n\n    # Setup DDP:cd\n    dist.init_process_group(\"nccl\")\n    rank = dist.get_rank()\n    device = rank % torch.cuda.device_count()\n    seed = args.global_seed * dist.get_world_size() + rank\n    torch.manual_seed(seed)\n    torch.cuda.set_device(device)\n    print(f\"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.\")\n\n    # setup dtype\n    dtype = torch.bfloat16\n\n    # Load model:\n    config = OmegaConf.load(args.config)\n    model_config = config.model \n    \n    if 'dc-ae' in model_config.vae_dir:\n        dc_ae = AutoencoderDC.from_pretrained(model_config.vae_dir).to(device)\n        if args.slice_vae:\n            dc_ae.enable_slicing()\n        if args.slice_vae:\n            dc_ae.enable_slicing()\n        spatial_downsample = 32\n        decode_func = functools.partial(dc_ae_decode, dc_ae)\n    elif 'sd-vae' in model_config.vae_dir:\n        sd_vae = AutoencoderKL.from_pretrained(model_config.vae_dir).to(device)\n        if args.slice_vae:\n            sd_vae.enable_slicing()\n        if args.slice_vae:\n            sd_vae.enable_slicing()\n        spatial_downsample = 8\n        decode_func = functools.partial(sd_vae_decode, sd_vae)\n    else: raise\n    assert args.cfg_scale >= 1.0, \"In almost all cases, cfg_scale be >= 1.0\"\n    # image resolution\n    patch_size = int(model_config.network.params.patch_size)\n    latent_h = int(args.height / spatial_downsample / patch_size)\n    latent_w = int(args.width / spatial_downsample / patch_size)\n\n\n    if args.interpolation != 'no':    \n        model_config.network.params['custom_freqs'] = args.interpolation\n        model_config.network.params['max_pe_len_h'] = latent_h\n        model_config.network.params['max_pe_len_w'] = latent_w\n        model_config.network.params['decouple'] = args.decouple\n        model_config.network.params['ori_max_pe_len'] = 
int(args.ori_max_pe_len)\n    \n    model = instantiate_from_config(model_config.network).to(device=device, dtype=dtype)\n    init_from_ckpt(model, checkpoint_dir=args.ckpt, ignore_keys=None, verbose=True)\n    model.eval()  # important!\n    \n    if args.ag_config != None and args.ag_ckpt != None:\n        ag_config = OmegaConf.load(args.ag_config)\n        ag_model_config = ag_config.model \n        ag_model = instantiate_from_config(ag_model_config.network).to(device=device, dtype=dtype)\n        init_from_ckpt(ag_model, checkpoint_dir=args.ag_ckpt, ignore_keys=None, verbose=True)\n        ag_model.eval()  # important!\n    else:\n        ag_model = None\n    \n    \n    # Create folder to save samples:\n    train_iter = args.ckpt.split('/')[-2].split('-')[-1]\n    folder_name = f\"{train_iter}-{args.height}x{args.width}-{args.mode}-{args.num_steps}-\" \\\n                  f\"cfg-{args.cfg_scale}-low-{args.guidance_low}-high-{args.guidance_high}\"\n    if ag_model != None:\n        sample_folder_dir = f\"{args.sample_dir}/ag-{folder_name}\"\n    else:\n        sample_folder_dir = f\"{args.sample_dir}/{folder_name}\"\n    if args.interpolation != 'no':\n        sample_folder_dir += f'-{args.interpolation}'\n    if rank == 0:\n        os.makedirs(sample_folder_dir, exist_ok=True)\n        print(f\"Saving .png samples at {sample_folder_dir}\")\n    dist.barrier()\n\n    # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:\n    n = args.per_proc_batch_size\n    global_batch_size = n * dist.get_world_size()\n    # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:\n    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)\n    if rank == 0:\n        print(f\"Total number of images that will be sampled: {total_samples}\")\n        print(f\"Model Parameters: {sum(p.numel() for p in model.parameters()):,}\")\n    assert 
total_samples % dist.get_world_size() == 0, \"total_samples must be divisible by world_size\"\n    samples_needed_this_gpu = int(total_samples // dist.get_world_size())\n    assert samples_needed_this_gpu % n == 0, \"samples_needed_this_gpu must be divisible by the per-GPU batch size\"\n    iterations = int(samples_needed_this_gpu // n)\n    pbar = range(iterations)\n    pbar = tqdm(pbar) if rank == 0 else pbar\n    total = 0\n    for i in pbar:\n        \n        # Sample inputs:\n        z = torch.randn(\n            (n*latent_h*latent_w, model.in_channels, patch_size, patch_size), \n            device=device, dtype=dtype\n        )\n        y = torch.randint(0, args.num_classes, (n,), device=device)\n        hw_list = torch.tensor([[latent_h, latent_w] for _ in range(n)], device=device, dtype=torch.int)\n        seqlens = hw_list[:, 0] * hw_list[:, 1]\n        cu_seqlens = torch.cat([\n            torch.tensor([0], device=hw_list.device, dtype=torch.int32), \n            torch.cumsum(seqlens, dim=0, dtype=torch.int32)\n        ])\n        \n        can_pass = True\n        for j in range(n):\n            index = j * dist.get_world_size() + rank + total\n            if not os.path.exists(f\"{sample_folder_dir}/{index:06d}.png\"):\n                can_pass = False\n        if can_pass:\n            total += global_batch_size\n            print('total: ', total)\n            continue\n\n        # Sample images:\n        sampling_kwargs = dict(\n            model=model, \n            ag_model=ag_model,\n            latents=z,\n            y=y,\n            hw_list=hw_list,\n            num_steps=args.num_steps, \n            heun=args.heun,\n            cfg_scale=args.cfg_scale,\n            guidance_low=args.guidance_low,\n            guidance_high=args.guidance_high,\n            path_type=args.path_type,\n        )\n        with torch.no_grad():\n            if args.mode == \"sde\":\n                samples = 
euler_maruyama_sampler(**sampling_kwargs).to(torch.float32)\n            elif args.mode == \"ode\":\n                samples = euler_sampler(**sampling_kwargs).to(torch.float32)\n            else:\n                raise NotImplementedError\n\n            samples = rearrange(samples, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=latent_h, w=latent_w)\n            samples = decode_func(samples)\n            samples = (samples + 1) / 2.\n            samples = torch.clamp(\n                255. * samples, 0, 255\n            ).permute(0, 2, 3, 1).to(\"cpu\", dtype=torch.uint8).numpy()\n\n            # Save samples to disk as individual .png files\n            for i, sample in enumerate(samples):\n                index = i * dist.get_world_size() + rank + total\n                Image.fromarray(sample).save(f\"{sample_folder_dir}/{index:06d}.png\")\n        total += global_batch_size\n\n    # Make sure all processes have finished saving their samples before attempting to convert to .npz\n    dist.barrier()\n    if rank == 0:\n        # create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)\n        print(\"Done.\")\n    dist.barrier()\n    dist.destroy_process_group()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    # seed\n    parser.add_argument(\"--global-seed\", type=int, default=0)\n\n    # precision\n    parser.add_argument(\"--tf32\", action=argparse.BooleanOptionalAction, default=True,\n                        help=\"By default, use TF32 matmuls. 
This massively accelerates sampling on Ampere GPUs.\")\n\n    # logging/saving:\n    parser.add_argument(\"--config\", type=str, default=None, help=\"Optional config to a SiT checkpoint.\")\n    parser.add_argument(\"--ckpt\", type=str, default=None, help=\"Optional path to a SiT checkpoint.\")\n    parser.add_argument(\"--sample-dir\", type=str, default=\"workdir/c2i/samples\")\n    parser.add_argument(\"--ag-config\", type=str, default=None)\n    parser.add_argument(\"--ag-ckpt\", type=str, default=None)\n\n\n    # model\n    parser.add_argument(\"--num-classes\", type=int, default=1000)\n    parser.add_argument(\"--height\", type=int, default=256)\n    parser.add_argument(\"--width\", type=int, default=256)\n    parser.add_argument(\"--slice_vae\", action=argparse.BooleanOptionalAction, default=False) # only for ode\n    \n    # number of samples\n    parser.add_argument(\"--per-proc-batch-size\", type=int, default=32)\n    parser.add_argument(\"--num-fid-samples\", type=int, default=50_000)\n\n    # sampling related hyperparameters\n    parser.add_argument(\"--mode\", type=str, default=\"ode\")\n    parser.add_argument(\"--cfg-scale\",  type=float, default=1.5)\n    parser.add_argument(\"--path-type\", type=str, default=\"linear\", choices=[\"linear\", \"cosine\"])\n    parser.add_argument(\"--num-steps\", type=int, default=50)\n    parser.add_argument(\"--heun\", action=argparse.BooleanOptionalAction, default=False) # only for ode\n    parser.add_argument(\"--guidance-low\", type=float, default=0.)\n    parser.add_argument(\"--guidance-high\", type=float, default=1.)\n\n    parser.add_argument(\"--interpolation\", type=str, choices=['no', 'linear', 'ntk-aware', 'ntk-by-parts', 'yarn', 'ntk-aware-pro1', 'ntk-aware-pro2', 'scale1', 'scale2'], default='no') # interpolation\n    parser.add_argument(\"--ori-max-pe-len\", default=None, type=int)\n    parser.add_argument(\"--decouple\", default=False, action=\"store_true\") # interpolation\n    \n    # will be 
deprecated\n    parser.add_argument(\"--legacy\", action=argparse.BooleanOptionalAction, default=False) # only for ode\n    \n\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "projects/train/packed_trainer_c2i.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport copy\nimport functools\nimport gc\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport time\nimport random\nimport shutil\nimport importlib\nimport csv\nimport numpy as np\nimport os.path as osp\nfrom pathlib import Path\nfrom typing import List, Union\nfrom packaging import version\nfrom tqdm.auto import tqdm\nfrom copy import deepcopy\nfrom omegaconf import OmegaConf\nfrom einops import rearrange, repeat\n\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport torchvision.transforms.functional as TF\nfrom torch.utils.data import default_collate, Dataset\nfrom torchvision import transforms\nfrom torchvision.transforms import Normalize\n\nimport accelerate\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\n\nimport transformers\n\nimport diffusers\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom diffusers import AutoencoderKL\n\nfrom timeit import default_timer as timer\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n\nfrom 
nit.schedulers.flow_matching.loss import FlowMatchingLoss\nfrom nit.data.packed_c2i_data import C2ILoader\nfrom nit.utils.misc_utils import (\n    get_obj_from_str, get_dtype, instantiate_from_config\n)\nfrom nit.utils.train_utils import (\n    update_ema, log_validation,\n)\nfrom nit.utils.gpu_memory_monitor import build_gpu_memory_monitor\n\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.18.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    # ----General Training Arguments----\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"\",\n        help=\"The config file for training.\",\n    )\n    parser.add_argument(\n        \"--project_dir\",\n        type=str,\n        default=\"t2i_linear_attention\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--seed\", \n        type=int, \n        default=None, \n        help=\"A seed for reproducible training.\"\n    )\n    args = parser.parse_args()\n    return args\n\n\n\ndef main(args):\n    project_dir = args.project_dir\n    config = OmegaConf.load(args.config)\n    model_config = config.model \n    data_config = config.data\n    train_config = config.training\n\n    config_dir = osp.join(project_dir, 'configs')\n    checkpoint_dir = osp.join(project_dir, 'checkpoints')\n    logging_dir = osp.join(project_dir, 'logs')\n    sample_dir = osp.join(project_dir, 'samples')\n\n    if getattr(train_config, 'fsdp_config', None) != None:\n        import functools\n        from torch.distributed.fsdp.fully_sharded_data_parallel import (\n            BackwardPrefetch, CPUOffload, ShardingStrategy, MixedPrecision, \n            StateDictType, FullStateDictConfig, FullOptimStateDictConfig,\n        )\n        from 
accelerate.utils import FullyShardedDataParallelPlugin\n        from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy\n\n        fsdp_cfg = train_config.fsdp_config\n        if train_config.mixed_precision == \"fp16\":\n            dtype = torch.float16\n        elif train_config.mixed_precision == \"bf16\":\n            dtype = torch.bfloat16\n        else:\n            dtype = torch.float32   \n        fsdp_plugin = FullyShardedDataParallelPlugin(\n            sharding_strategy = {\n                'FULL_SHARD': ShardingStrategy.FULL_SHARD,\n                'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,\n                'NO_SHARD': ShardingStrategy.NO_SHARD,\n                'HYBRID_SHARD': ShardingStrategy.HYBRID_SHARD,\n                'HYBRID_SHARD_ZERO2': ShardingStrategy._HYBRID_SHARD_ZERO2,\n            }[fsdp_cfg.sharding_strategy],\n            backward_prefetch = {\n                'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,\n                'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,\n            }[fsdp_cfg.backward_prefetch],\n            mixed_precision_policy = MixedPrecision(\n                param_dtype=dtype,\n                reduce_dtype=dtype,\n            ),\n            auto_wrap_policy = functools.partial(\n                size_based_auto_wrap_policy, min_num_params=fsdp_cfg.min_num_params\n            ),\n            cpu_offload = CPUOffload(offload_params=fsdp_cfg.cpu_offload),\n            state_dict_type = {\n                'FULL_STATE_DICT': StateDictType.FULL_STATE_DICT,\n                'LOCAL_STATE_DICT': StateDictType.LOCAL_STATE_DICT,\n                'SHARDED_STATE_DICT': StateDictType.SHARDED_STATE_DICT\n            }[fsdp_cfg.state_dict_type],\n            state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True),\n            optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),\n            limit_all_gathers = fsdp_cfg.limit_all_gathers,\n          
  use_orig_params = fsdp_cfg.use_orig_params,\n            sync_module_states = fsdp_cfg.sync_module_states,\n            forward_prefetch = fsdp_cfg.forward_prefetch,\n            activation_checkpointing = fsdp_cfg.activation_checkpointing,\n        )\n    else:\n        fsdp_plugin = None\n\n    accelerator_project_config = ProjectConfiguration(project_dir=project_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=train_config.gradient_accumulation_steps,\n        mixed_precision=train_config.mixed_precision,\n        log_with=train_config.tracker,\n        project_config=accelerator_project_config,\n        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes\n        fsdp_plugin=fsdp_plugin,\n    )\n    \n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        os.makedirs(project_dir, exist_ok=True)\n        os.makedirs(config_dir, exist_ok=True)\n        os.makedirs(checkpoint_dir, exist_ok=True)\n        os.makedirs(logging_dir, exist_ok=True)\n        os.makedirs(sample_dir, exist_ok=True)\n        OmegaConf.save(config=config, f=osp.join(config_dir, \"config.yaml\"))\n    \n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If 
passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    if train_config.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    total_batch_size = (\n        data_config.dataloader.batch_size * \n        accelerator.num_processes * \n        train_config.gradient_accumulation_steps\n    )\n    if train_config.scale_lr:\n        learning_rate = (\n            train_config.learning_rate * \n            total_batch_size / train_config.learning_rate_base_batch_size\n        )\n    else:\n        learning_rate = train_config.learning_rate\n    \n    \n    # prepare model, dataloader, optimizer and scheduler\n    model = instantiate_from_config(model_config.network).to(device=accelerator.device)\n    model.train()\n    if model_config.use_ema:\n        ema_model = deepcopy(model)\n        ema_model.train()\n        ema_model.requires_grad_(False)\n    # Handle mixed precision and device placement\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training, copy of the weights should still be float32.\"\n    )\n    if accelerator.unwrap_model(model).dtype != torch.float32:\n        raise ValueError(\n            f\"Controlnet loaded as datatype {accelerator.unwrap_model(model).dtype}. 
{low_precision_error_string}\"\n        )\n    \n    if accelerator.is_main_process:\n        total_params = 0\n        trainable_params = 0\n        projector_params = 0\n        for name, param in model.named_parameters():\n            print(name, param.requires_grad)\n            total_params += param.numel()  # Total number of elements in the parameter\n            if param.requires_grad:          # Check if the parameter is trainable\n                trainable_params += param.numel()\n            if 'projector' in name:\n                projector_params += param.numel()\n        print(trainable_params, total_params, total_params-projector_params, trainable_params/total_params)\n    \n    # Optimizer creation\n    target_optimizer = train_config.optimizer.get('target', 'torch.optim.AdamW')\n    optimizer = get_obj_from_str(target_optimizer)(\n        model.parameters(), lr=learning_rate, \n        **train_config.optimizer.get(\"params\", dict())\n    )\n\n    # Dataset creation and data processing\n    # Here, we compute not just the text embeddings but also the additional embeddings\n    global_steps = 0\n    if train_config.resume_from_checkpoint:\n        # normal read with safety check\n        if train_config.resume_from_checkpoint != \"latest\":\n            resume_from_path = os.path.basename(train_config.resume_from_checkpoint)\n        else:   # Get the most recent checkpoint\n            dirs = os.listdir(checkpoint_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            resume_from_path = osp.join(checkpoint_dir, dirs[-1]) if len(dirs) > 0 else None\n\n        if resume_from_path is None:\n            logger.info(\n                f\"Checkpoint '{train_config.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            train_config.resume_from_checkpoint = None\n        else:\n            global_steps = int(resume_from_path.split(\"-\")[1]) # gs not calculate the gradient_accumulation_steps\n            logger.info(f\"Resuming from steps: {global_steps}\")\n    \n    get_train_dataloader = C2ILoader(data_config)\n    train_dataloader = get_train_dataloader.train_dataloader(\n        rank=accelerator.process_index, world_size=accelerator.num_processes, \n        global_batch_size=total_batch_size, max_steps=train_config.max_train_steps, \n        resume_steps=global_steps, seed=args.seed\n    )\n\n    # LR Scheduler creation\n    # Scheduler and math around the number of training steps.\n    lr_scheduler = get_scheduler(\n        train_config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=train_config.lr_warmup_steps,\n        num_training_steps=train_config.max_train_steps,\n    )\n\n    # Prepare for training\n    # Prepare everything with our `accelerator`.\n    if model_config.use_ema:\n        ema_model, model, optimizer, lr_scheduler = accelerator.prepare(\n            ema_model, model, optimizer, lr_scheduler\n        )\n    else:\n        model, optimizer, lr_scheduler = accelerator.prepare(\n            model, optimizer, lr_scheduler\n        )\n\n    # transport \n    loss_fn = FlowMatchingLoss(**OmegaConf.to_container(model_config.transport))\n    if model_config.enc_type == 'radio':\n        from nit.models.nvidia_radio.hubconf import radio_model\n        encoder = radio_model(version=model_config.enc_dir, progress=True, support_packing=True)\n        encoder.to(device=accelerator.device).eval()\n        encoder.requires_grad_(False)\n    \n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process and getattr(train_config, 'tracker', 'wandb') != None:\n        
tracker_project_name = project_dir.split('/')[-1]\n        # accelerator.init_trackers(\"mcga\", config=config, init_kwargs=train_config.tracker_kwargs)\n        accelerator.init_trackers(tracker_project_name, config=config, init_kwargs=train_config.tracker_kwargs)\n\n    \n    # initialize GPU memory monitor before applying parallelisms to the model\n    gpu_memory_monitor = build_gpu_memory_monitor(logger)\n    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()\n\n    # 15. Train!\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {get_train_dataloader.train_len()/data_config.dataloader.batch_size}\")\n    logger.info(f\"  Dataset Length = {get_train_dataloader.train_len()}\")\n    logger.info(f\"  Instantaneous batch size per device = {data_config.dataloader.batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {train_config.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {train_config.max_train_steps}\")\n    logger.info(\n        \"  GPU memory usage for model: \"\n        f\"{gpu_mem_stats.max_reserved_gib:.2f}GiB\"\n        f\"({gpu_mem_stats.max_reserved_pct:.2f}%)\"\n    )\n\n    gpu_memory_monitor.reset_peak_stats()\n    data_loading_times = []\n    feat_enc_times = []\n    \n    # Potentially load in the weights and states from a previous save\n    if train_config.resume_from_checkpoint and resume_from_path != None:\n        accelerator.print(f\"Resuming from checkpoint {resume_from_path}\")\n        accelerator.load_state(resume_from_path)\n\n    progress_bar = tqdm(\n        range(0, train_config.max_train_steps),\n        initial=global_steps,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_main_process,\n    )\n\n    \n    for batch in train_dataloader:\n        time_last_log = 
timer()\n        data_load_start = timer()\n        # load dataset from batch\n        batch_image = [image.to(accelerator.device) for image in batch['image']]\n        batch_label = batch['label'].squeeze(0).to(accelerator.device, torch.int)\n        packed_latent = batch['latent'].squeeze(0).to(accelerator.device)\n        noises = torch.randn_like(packed_latent)\n        hw_list = batch['hw_list'].squeeze(0).to(torch.int)\n        batch_size = hw_list.shape[0]\n        \n        dropout_prob = model_config.network.params.class_dropout_prob\n        num_classes = model_config.network.params.num_classes\n        if dropout_prob > 0:\n            drop_ids = torch.rand(batch_label.shape[0], device=accelerator.device) < dropout_prob\n            batch_label = torch.where(drop_ids, num_classes, batch_label)\n        data_loading_times.append(timer() - data_load_start)\n                \n        feat_enc_start = timer()\n        zs = []\n        if model_config.enc_type == 'radio':\n            with torch.no_grad(), accelerator.autocast():\n                raw_images = [(image.unsqueeze(0)+1.0)/2.0 for image in batch_image]\n                _, z = encoder.forward_pack(raw_images)\n                zs.append(z)\n        feat_enc_times.append(timer() - feat_enc_start)\n\n        with accelerator.accumulate(model):\n            # forward and calculate loss\n            model_kwargs = dict(y=batch_label, hw_list=hw_list)\n            fm_loss, proj_loss = loss_fn(model, batch_size, packed_latent, noises, model_kwargs, use_dir_loss=True, zs=zs)\n            loss = fm_loss + model_config.proj_coeff * proj_loss\n            accelerator.backward(loss)\n            if accelerator.sync_gradients and train_config.max_grad_norm > 0:\n                all_norm = accelerator.clip_grad_norm_(model.parameters(), train_config.max_grad_norm)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad(set_to_none=True)\n\n        # Checks if the 
accelerator has performed an optimization step behind the scenes\n        if accelerator.sync_gradients:\n            # 20.4.15. Make EMA update to target student model parameters\n            if model_config.use_ema:\n                update_ema(ema_model, model, model_config.ema_decay)\n            global_steps += 1\n            time_delta = timer() - time_last_log\n            sps = batch_size / time_delta\n            time_data_loading = np.mean(data_loading_times)\n            time_feat_enc = np.mean(feat_enc_times)\n            time_data_loading_pct = 100 * time_data_loading / time_delta\n            time_feat_enc_pct = 100 * time_feat_enc / time_delta\n            gpu_mem_stats = gpu_memory_monitor.get_peak_stats()\n            \n            if global_steps % train_config.checkpointing_steps == 0:\n                # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`\n                if accelerator.is_main_process and train_config.checkpoints_total_limit is not None:\n                    checkpoints = os.listdir(checkpoint_dir)\n                    checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                    checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                    # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                    if len(checkpoints) >= train_config.checkpoints_total_limit:\n                        num_to_remove = len(checkpoints) - train_config.checkpoints_total_limit + 1\n                        removing_checkpoints = checkpoints[0:num_to_remove]\n\n                        logger.info(\n                            f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                        )\n                        logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                        for 
removing_checkpoint in removing_checkpoints:\n                            removing_checkpoint = osp.join(checkpoint_dir, removing_checkpoint)\n                            try:\n                                shutil.rmtree(removing_checkpoint)\n                            except:\n                                pass\n                save_path = osp.join(checkpoint_dir, f\"checkpoint-{global_steps}\")\n                if accelerator.is_main_process:\n                    os.makedirs(save_path, exist_ok=True)\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n\n                \n\n                if global_steps in train_config.checkpoint_list:\n                    save_path = os.path.join(checkpoint_dir, f\"save-checkpoint-{global_steps}\")\n                    if accelerator.is_main_process:\n                        os.makedirs(save_path, exist_ok=True)\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n                time.sleep(10)\n                torch.cuda.empty_cache()\n            \n                if global_steps % train_config.validation_steps == 0:\n                    log_validation(model)\n            logs = {\n                # loss and lr\n                \"loss_denoising\": fm_loss.detach().item(), \n                \"loss_projector\": proj_loss.detach().item(), \n                \"lr\": lr_scheduler.get_last_lr()[0],\n                # time and status\n                \"sps\": sps,\n                \"data_loading(s)\": time_data_loading,\n                \"data_loading(%)\": time_data_loading_pct,\n                \"time_feat_enc(s)\": time_feat_enc,\n                \"time_feat_enc(%)\": time_feat_enc_pct,\n                \"memory_max_active(GiB)\": gpu_mem_stats.max_active_gib,\n                \"memory_max_active(%)\": gpu_mem_stats.max_active_pct,\n                \"memory_max_reserved(GiB)\": 
gpu_mem_stats.max_reserved_gib,\n                \"memory_max_reserved(%)\": gpu_mem_stats.max_reserved_pct,\n                \"memory_num_alloc_retries\": gpu_mem_stats.num_alloc_retries,\n                \"memory_num_ooms\": gpu_mem_stats.num_ooms\n            }\n            if accelerator.sync_gradients and train_config.max_grad_norm > 0:\n                logs.update({'grad_norm': all_norm.item()})\n            progress_bar.set_postfix(**logs)\n            progress_bar.update(1)\n            accelerator.log(logs, step=global_steps)\n        if global_steps >= train_config.max_train_steps:\n            break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "diffusers>=0.30.1 #git+https://github.com/huggingface/diffusers.git@main#egg=diffusers is suggested\ntransformers>=4.44.2  # The development team is working on version 4.44.2\naccelerate>=0.33.0 #git+https://github.com/huggingface/accelerate.git@main#egg=accelerate is suggested\nsentencepiece>=0.2.0 # T5 used\nnumpy==1.26.0\nstreamlit>=1.38.0 # For streamlit web demo\nimageio==2.34.2 # For diffusers inference export video\nimageio-ffmpeg==0.5.1 # For diffusers inference export video\nmoviepy==1.0.3 # For export video\npillow==9.5.0\ntimm\nsafetensors\neinops\ntriton\ntorchdiffeq"
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_256x256.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/preprocess/image_latent_c2i.py \\\n    --config configs/preprocess/imagenet1k_256x256.yaml \\\n    --project_dir workdir/preprocess/imagenet1k_256x256 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_512x512.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/preprocess/image_latent_c2i.py \\\n    --config configs/preprocess/imagenet1k_512x512.yaml \\\n    --project_dir workdir/preprocess/imagenet1k_512x512 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/preprocess/preorocess_in1k_native_resolution.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=$((30000 + $RANDOM % 21000))\n\nCMD=\" \\\n    projects/preprocess/image_nr_latent_c2i.py \\\n    --config configs/preprocess/imagenet1k_native_resolution.yaml \\\n    --project_dir workdir/preprocess/imagenet1k_native_resolution \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/sample/sample_256x256.sh",
    "content": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \\\n  --ckpt checkpoints/nit_xl_model_1000K.safetensors \\\n  --sample-dir ./samples \\\n  --height 256 \\\n  --width 256 \\\n  --per-proc-batch-size 32 \\\n  --mode sde \\\n  --num-steps 250 \\\n  --cfg-scale 2.25 \\\n  --guidance-low 0.0 \\\n  --guidance-high 0.7 \\\n  --slice_vae \\"
  },
  {
    "path": "scripts/sample/sample_512x512.sh",
    "content": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \\\n  --ckpt checkpoints/nit_xl_model_1000K.safetensors \\\n  --sample-dir ./samples \\\n  --height 512 \\\n  --width 512 \\\n  --per-proc-batch-size 32 \\\n  --mode sde \\\n  --num-steps 250 \\\n  --cfg-scale 2.05 \\\n  --guidance-low 0.0 \\\n  --guidance-high 0.7 \\\n  --slice_vae \\"
  },
  {
    "path": "scripts/sample/sample_768x768.sh",
    "content": "torchrun \\\n  --nnodes 1 \\\n  --nproc_per_node 8 \\\n  projects/sample/sample_c2i_ddp.py \\\n  --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \\\n  --ckpt checkpoints/nit_xl_model_1000K.safetensors \\\n  --sample-dir ./samples \\\n  --height 768 \\\n  --width 768 \\\n  --per-proc-batch-size 32 \\\n  --mode ode \\\n  --num-steps 50 \\\n  --cfg-scale 3.0 \\\n  --guidance-low 0.0 \\\n  --guidance-high 0.7 \\\n  --slice_vae \\"
  },
  {
    "path": "scripts/train/train_b_model.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_b_pack_merge_radio_65536\nCMD=\" \\\n    projects/train/packed_trainer_c2i.py \\\n    --config configs/c2i/nit_b_pack_merge_radio_65536.yaml \\\n    --project_dir workdir/c2i/nit_b_pack_merge_radio_65536 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/train/train_l_model.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_l_pack_merge_radio_16384\nCMD=\" \\\n    projects/train/packed_trainer_c2i.py \\\n    --config configs/c2i/nit_l_pack_merge_radio_16384.yaml \\\n    --project_dir workdir/c2i/nit_l_pack_merge_radio_16384 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/train/train_s_model.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=2\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_s_pack_merge_radio_65536\nCMD=\" \\\n    projects/train/packed_trainer_c2i.py \\\n    --config configs/c2i/nit_s_pack_merge_radio_65536.yaml \\\n    --project_dir workdir/c2i/nit_s_pack_merge_radio_65536 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/train/train_xl_model.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_xl_pack_merge_radio_16384\nCMD=\" \\\n    projects/train/packed_trainer_c2i.py \\\n    --config configs/c2i/nit_xl_pack_merge_radio_16384.yaml \\\n    --project_dir workdir/c2i/nit_xl_pack_merge_radio_16384 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "scripts/train/train_xxl_model.sh",
    "content": "NNODES=1\nGPUS_PER_NODE=8\nMASTER_ADDR=\"localhost\"\nexport MASTER_PORT=60563\nmkdir -p workdir/c2i/nit_xxl_pack_merge_radio_8192\nCMD=\" \\\n    projects/train/packed_trainer_c2i.py \\\n    --config configs/c2i/nit_xxl_pack_merge_radio_8192.yaml \\\n    --project_dir workdir/c2i/nit_xxl_pack_merge_radio_8192 \\\n    --seed 0 \\\n    \"\nTORCHLAUNCHER=\"torchrun \\\n    --nnodes $NNODES \\\n    --nproc_per_node $GPUS_PER_NODE \\\n    --rdzv_id $RANDOM \\\n    --rdzv_backend c10d \\\n    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \\\n    \"\nbash -c \"$TORCHLAUNCHER $CMD\""
  },
  {
    "path": "setup.py",
    "content": "from setuptools import find_packages, setup\n\nsetup(\n    name='nit',\n    version='0.0.1',\n    description='',\n    packages=find_packages(),\n    install_requires=[\n        'torch',\n        'numpy',\n    ],\n)"
  },
  {
    "path": "tools/download_dataset_256x256.sh",
    "content": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-256x256\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-256x256\"\nfiles=(\n    \"n01440764_n02097298.zip\"\n    \"n02097474_n02667093.zip\"\n    \"n02669723_n03530642.zip\"\n    \"n03532672_n04239074.zip\"\n    \"n04243546_n15075141.zip\"\n)\nfor file in \"${files[@]}\"; do\n    echo \"download $file ...\"\n    wget -c \"$base_url/$file\" -O \"$target_dir/$file\"\n    echo \"download $file finished\"\n    echo \"start unzip $file ...\"\n    unzip \"$target_dir/$file\" -d \"$target_dir\"\n    echo \"unzip $file finished\"\n    rm \"$target_dir/$file\"\n    echo\ndone\necho \"Successfully downloaded all the dataset files\"\n\n"
  },
  {
    "path": "tools/download_dataset_512x512.sh",
    "content": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-512x512\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-512x512\"\nfiles=(\n    \"n01440764_n01697457.zip\"\n    \"n01698640_n01855672.zip\"\n    \"n01860187_n02074367.zip\"\n    \"n02077923_n02097298.zip\"\n    \"n02097474_n02110063.zip\"\n    \"n02110185_n02138441.zip\"\n    \"n02165105_n02415577.zip\"\n    \"n02417914_n02667093.zip\"\n    \"n02669723_n02859443.zip\"\n    \"n02860847_n03041632.zip\"\n    \"n03042490_n03291819.zip\"\n    \"n03297495_n03530642.zip\"\n    \"n03532672_n03743016.zip\"\n    \"n03759954_n03884397.zip\"\n    \"n03887697_n04033901.zip\"\n    \"n04033995_n04239074.zip\"\n    \"n04243546_n04398044.zip\"\n    \"n04399382_n04560804.zip\"\n    \"n04562935_n07745940.zip\"\n    \"n07747607_n15075141.zip\"\n)\nfor file in \"${files[@]}\"; do\n    echo \"download $file ...\"\n    wget -c \"$base_url/$file\" -O \"$target_dir/$file\"\n    echo \"download $file finished\"\n    echo \"start unzip $file ...\"\n    unzip \"$target_dir/$file\" -d \"$target_dir\"\n    echo \"unzip $file finished\"\n    rm \"$target_dir/$file\"\n    echo\ndone\necho \"Successfully downloaded all the dataset files\"\n\n"
  },
  {
    "path": "tools/download_dataset_data_meta.sh",
    "content": "target_dir=\"datasets/imagenet1k/data_meta\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/data_meta\"\nfiles=(\n    \"dc-ae-f32c32-sana-1.1-diffusers_256x256_meta.jsonl\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_512x512_meta.jsonl\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_nr_meta.jsonl\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl\"\n)\nfor file in \"${files[@]}\"; do\n    echo \"download $file ...\"\n    wget -c \"$base_url/$file\" -O \"$target_dir/$file\"\n    echo \"download $file finished\"\n    echo\ndone\necho \"Successfully downloaded all the data-meta\"\n\n"
  },
  {
    "path": "tools/download_dataset_native_resolution.sh",
    "content": "target_dir=\"datasets/imagenet1k/dc-ae-f32c32-sana-1.1-diffusers-native-resolution\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/dc-ae-f32c32-sana-1.1-diffusers-native-resolution\"\nfiles=(\n    \"n01440764_n01855672.zip\"\n    \"n01860187_n02097298.zip\"\n    \"n02097474_n02138441.zip\"\n    \"n02165105_n02667093.zip\"\n    \"n02669723_n03041632.zip\"\n    \"n03042490_n03530642.zip\"\n    \"n03532672_n03884397.zip\"\n    \"n03887697_n04239074.zip\"\n    \"n04243546_n04560804.zip\"\n    \"n04562935_n15075141.zip\"\n)\nfor file in \"${files[@]}\"; do\n    echo \"download $file ...\"\n    wget -c \"$base_url/$file\" -O \"$target_dir/$file\"\n    echo \"download $file finished\"\n    echo \"start unzip $file ...\"\n    unzip \"$target_dir/$file\" -d \"$target_dir\"\n    echo \"unzip $file finished\"\n    rm \"$target_dir/$file\"\n    echo\ndone\necho \"Successfully downloaded all the dataset files\"\n\n"
  },
  {
    "path": "tools/download_dataset_sampler_meta.sh",
    "content": "target_dir=\"datasets/imagenet1k/sampler_meta\"\nmkdir -p $target_dir\nbase_url=\"https://huggingface.co/datasets/GoodEnough/NiT-Preprocessed-ImageNet1K/resolve/main/sampler_meta\"\nfiles=(\n    \"dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_8192.json\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_16384.json\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_32768.json\"\n    \"dc-ae-f32c32-sana-1.1-diffusers_merge_LPFHP_65536.json\"\n)\nfor file in \"${files[@]}\"; do\n    echo \"download $file ...\"\n    wget -c \"$base_url/$file\" -O \"$target_dir/$file\"\n    echo \"download $file finished\"\n    echo\ndone\necho \"Successfully downloaded all the sampler-meta\"\n\n"
  },
  {
    "path": "tools/pack_dataset.py",
    "content": "import json\nfrom nit.data.pack import pack_dataset\nimport argparse\n\n\n\ndef create_pack(data_meta, max_seq_len, algorithm, split):  \n    max_seq_per_pack = max_seq_len\n    with open(data_meta, 'r') as fp:\n        ori_dataset = [json.loads(line) for i, line in enumerate(fp)]\n    dataset_seq_lens = []\n    dataset_seq_idxs = []\n    for idx, data in enumerate(ori_dataset):\n        seq_len = int(data['latent_h']*data['latent_w'])   # patch_size=1\n        dataset_seq_lens.append(seq_len)\n        dataset_seq_idxs.append(idx)\n    total_length = len(ori_dataset)\n\n    run_length = int(total_length / split)\n    all_packed_indices = []\n    for i in range(split):\n        seq_lens = dataset_seq_lens[i*run_length: (i+1)*run_length]\n        seq_idxs = dataset_seq_idxs[i*run_length: (i+1)*run_length]\n        packed_indices = pack_dataset(\n            algorithm, max_seq_len, max_seq_per_pack, seq_lens, seq_idxs\n        ) \n        all_packed_indices.extend(packed_indices)\n    \n    sampler_json_name = data_meta.split('/')[-1].replace('_meta.jsonl', '')\n    sampler_json_name = f\"{sampler_json_name}_{algorithm}_{max_seq_len}.json\"\n    with open(f'datasets/imagenet1k/sampler_meta/{sampler_json_name}', 'w') as fp:\n        json.dump(all_packed_indices, fp, indent=4)\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    # seed\n    parser.add_argument(\"--data-meta\", type=str, default='datasets/imagenet1k/data_meta/dc-ae-f32c32-sana-1.1-diffusers_merge_meta.jsonl')\n    parser.add_argument(\"--max-seq-len\", type=int, default=16384)\n    parser.add_argument(\"--algorithm\", type=str, default='LPFHP')\n    parser.add_argument(\"--split\", type=int, default=1)\n    args = parser.parse_args()\n    create_pack(\n        data_meta=args.data_meta, max_seq_len=args.max_seq_len,\n        algorithm=args.algorithm, split=args.split\n    )\n"
  }
]