Full Code of DAMO-NLP-SG/Inf-CLIP for AI

Repository: DAMO-NLP-SG/Inf-CLIP
Branch: main
Commit: d9f2833b3753
Files: 152
Total size: 471.8 KB

Directory structure:
gitextract_u47kktxp/

├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── inf_cl/
│   ├── __init__.py
│   ├── flash.py
│   └── ring.py
├── inf_clip/
│   ├── __init__.py
│   ├── constants.py
│   ├── factory.py
│   ├── model_configs/
│   │   ├── EVA01-g-14-plus.json
│   │   ├── EVA01-g-14.json
│   │   ├── EVA02-B-16.json
│   │   ├── EVA02-E-14-plus.json
│   │   ├── EVA02-E-14.json
│   │   ├── EVA02-L-14-336.json
│   │   ├── EVA02-L-14.json
│   │   ├── LiT-B-16.json
│   │   ├── LiT-B-32.json
│   │   ├── LiT-L-16.json
│   │   ├── MobileCLIP-B.json
│   │   ├── MobileCLIP-S1.json
│   │   ├── MobileCLIP-S2.json
│   │   ├── RN101-quickgelu.json
│   │   ├── RN101.json
│   │   ├── RN50-quickgelu.json
│   │   ├── RN50.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── ViT-B-16-SigLIP-256.json
│   │   ├── ViT-B-16-SigLIP-384.json
│   │   ├── ViT-B-16-SigLIP-512.json
│   │   ├── ViT-B-16-SigLIP-i18n-256.json
│   │   ├── ViT-B-16-SigLIP.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-B-16-quickgelu.json
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32-256.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-H-14-378-quickgelu.json
│   │   ├── ViT-H-14-CLIPA-336.json
│   │   ├── ViT-H-14-CLIPA.json
│   │   ├── ViT-H-14-quickgelu.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14-CLIPA-336.json
│   │   ├── ViT-L-14-CLIPA.json
│   │   ├── ViT-L-14-quickgelu.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16-SigLIP-256.json
│   │   ├── ViT-L-16-SigLIP-384.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-SO400M-14-SigLIP-384.json
│   │   ├── ViT-SO400M-14-SigLIP.json
│   │   ├── ViT-bigG-14-CLIPA-336.json
│   │   ├── ViT-bigG-14-CLIPA.json
│   │   ├── ViT-bigG-14.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── ViTamin-B-LTT.json
│   │   ├── ViTamin-B.json
│   │   ├── ViTamin-L-256.json
│   │   ├── ViTamin-L-336.json
│   │   ├── ViTamin-L.json
│   │   ├── ViTamin-L2-256.json
│   │   ├── ViTamin-L2-336.json
│   │   ├── ViTamin-L2.json
│   │   ├── ViTamin-S-LTT.json
│   │   ├── ViTamin-S.json
│   │   ├── ViTamin-XL-256.json
│   │   ├── ViTamin-XL-336.json
│   │   ├── ViTamin-XL-384.json
│   │   ├── coca_ViT-B-32.json
│   │   ├── coca_ViT-L-14.json
│   │   ├── coca_base.json
│   │   ├── coca_roberta-ViT-B-32.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   ├── convnext_xxlarge_320.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── nllb-clip-base-siglip.json
│   │   ├── nllb-clip-base.json
│   │   ├── nllb-clip-large-siglip.json
│   │   ├── nllb-clip-large.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   └── xlm-roberta-large-ViT-H-14.json
│   ├── models/
│   │   ├── clip_arch.py
│   │   ├── coca_arch.py
│   │   ├── hf_configs.py
│   │   ├── hf_model.py
│   │   ├── lit_arch.py
│   │   ├── loss.py
│   │   ├── modified_resnet.py
│   │   ├── pos_embed.py
│   │   ├── timm_model.py
│   │   ├── tokenizer.py
│   │   ├── transform.py
│   │   └── transformer.py
│   ├── openai.py
│   ├── pretrained.py
│   ├── train/
│   │   ├── data.py
│   │   ├── engine.py
│   │   ├── main.py
│   │   ├── optims.py
│   │   ├── params.py
│   │   └── utils.py
│   ├── utils.py
│   ├── zero_shot_classifier.py
│   └── zero_shot_metadata.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│   ├── benchmarks_eval.sh
│   ├── cc12m/
│   │   ├── clip_vit-b-32_bs32k.sh
│   │   ├── lit_vit-b-16_bs32k.sh
│   │   └── lit_vit-b-32_bs32k.sh
│   ├── cc3m/
│   │   ├── clip_r50_bs4k.sh
│   │   ├── clip_vit-b-32_bs16k.sh
│   │   └── lit_vit-b-32_bs16k.sh
│   ├── imagenet_eval.sh
│   └── laion400m/
│       ├── clip_vit-b-32_bs256k.sh
│       ├── lit_vit-b-16_bs256k.sh
│       ├── lit_vit-b-32_bs256k.sh
│       └── lit_vit-l-16_bs256k.sh
└── tests/
    └── example.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitattributes
================================================
*.py linguist-language=python
*.ipynb linguist-documentation


================================================
FILE: .gitignore
================================================
**/logs/
**/wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*

# Allow
!src/evaluation/misc/results_dbs/*

# log dirs
/work_dirs*/
/datasets/

# oss logs
/.ossutil*
/ossutil*

================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
<p align="center">
    <img src="https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad" width="150" style="margin-bottom: 0.2;"/>
</p>

<h3 align="center"><a href="https://arxiv.org/abs/2410.17243">
Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss</a></h3>
<h5 align="center"> If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏 </h5>

<h5 align="center">

[![arXiv](https://img.shields.io/badge/Arxiv-2410.17243-AD1C18.svg?logo=arXiv)](https://arxiv.org/abs/2410.17243)
[![hf_paper](https://img.shields.io/badge/🤗-Paper%20In%20HF-red.svg)](https://huggingface.co/papers/2410.17243)
[![PyPI](https://img.shields.io/badge/PyPI-Inf--CL-9C276A.svg)](https://pypi.org/project/inf-cl) <br>
[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/LICENSE)
[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FDAMO-NLP-SG%2FInf-CLIP&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)
[![GitHub issues](https://img.shields.io/github/issues/DAMO-NLP-SG/Inf-CLIP?color=critical&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aopen+is%3Aissue)
[![GitHub closed issues](https://img.shields.io/github/issues-closed/DAMO-NLP-SG/Inf-CLIP?color=success&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aissue+is%3Aclosed)  <br>
[![zhihu](https://img.shields.io/badge/-知乎-000000?logo=zhihu&logoColor=0084FF)](https://zhuanlan.zhihu.com/p/1681887214)
[![Twitter](https://img.shields.io/badge/-Twitter-black?logo=twitter&logoColor=1D9BF0)](https://x.com/lixin4ever/status/1849669129613226457) <br>

</h5>

<div align="center"><img src="https://github.com/user-attachments/assets/2c19838b-43d8-4145-b28c-903f3d76f8ab" width="800" /></div>

<details open><summary>💡 Some other multimodal foundation model projects from our team may interest you ✨. </summary><p>

> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922) <br>
> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing <br>
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VCD)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VCD.svg?style=social)](https://github.com/DAMO-NLP-SG/VCD)  [![arXiv](https://img.shields.io/badge/Arxiv-2311.16922-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.16922) <br>

> [**VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs**](https://github.com/DAMO-NLP-SG/VideoLLaMA2) <br>
> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing <br>
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VideoLLaMA2)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA2.svg?style=social)](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [![arXiv](https://img.shields.io/badge/Arxiv-2406.07476-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.07476) <br>

> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787) <br>
> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing <br>
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/CMM)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/CMM.svg?style=social)](https://github.com/DAMO-NLP-SG/CMM)  [![arXiv](https://img.shields.io/badge/Arxiv-2410.12787-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.12787) <br>

</p></details>

## 📰 News
* **[2024.10.18]**  Release the training and evaluation code of Inf-CLIP.

<div align="center"><img src="https://github.com/user-attachments/assets/11c5cc32-aac2-497d-bbc1-33e065a71be0" width="800" /></div>

## 🛠️ Requirements and Installation

Basic Dependencies:
* Python >= 3.8
* Pytorch >= 2.0.0
* CUDA Version >= 11.8

[Remote] Install Inf-CL:
```bash
# remote installation from PyPI
pip install inf_cl -i https://pypi.org/simple
```

[Local] Clone the repository, install the required packages, then install Inf-CL in editable mode:
```bash
git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt
pip install -e .
```

## ⭐ Features

`inf_cl` is the Triton implementation of the Inf-CL loss (a minimal single-GPU sketch follows the list below):
* [x] [Ring-CL (inf_cl/ring.py#L238)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/models/ops/ring.py#L238)
* [x] [Inf-CL  (inf_cl/ring.py#L251)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/models/ops/ring.py#L251)
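
For reference, here is a minimal single-GPU sketch of these entry points (an illustrative example, assuming a CUDA device and a feature dimension divisible by the default `head_dim=256`). When `torch.distributed` is not initialized, `cal_inf_loss` and `cal_ring_loss` fall back to the single-GPU flash loss:
```python
# Illustrative single-GPU sketch: without an initialized process group,
# cal_inf_loss / cal_ring_loss fall back to the flash-loss path.
import torch
import torch.nn.functional as F
from inf_cl import cal_flash_loss, cal_ring_loss, cal_inf_loss

q = F.normalize(torch.rand(1024, 768, device="cuda"), dim=-1)  # e.g. image features
k = F.normalize(torch.rand(1024, 768, device="cuda"), dim=-1)  # e.g. text features
loss = cal_inf_loss(q, k)  # positive pairs are the diagonal by default
```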

`inf_clip` is the CLIP training codebase with Inf-CL loss and other training features:
- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](https://github.com/DAMO-NLP-SG/Inf-CLIP/inf_clip_train/train.py#L180)
- [x] [Gradient Cache (inf_clip/train/train.py#L292)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip_train/train.py#L292)


## 🔑 Usage

A simple example of how to adopt our Inf-CL loss for contrastive learning. Run it with the following command:
```bash
torchrun --nproc_per_node 2 tests/example.py
```

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3        # Number of attention heads
    seq_length_q = 32768 # Sequence length
    seq_length_k = 32768
    d_model = 256        # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key
    l = l.requires_grad_() # Logit scale

    return q, k, l


if __name__ == "__main__":
    # Initialize the distributed environment (e.g., launched via torchrun)
    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    torch.cuda.set_device(rank)

    # Taking image-text contrastive learning as an example: q is the image features
    # on this rank, k is the text features, and l is the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # labels are diagonal elements by default. 
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)

```
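
In CLIP training the loss is usually symmetrized over both directions. Mirroring the distributed test in `inf_cl/ring.py` (and reusing `q`, `k`, `l` from the example above), a sketch of the bidirectional variant:
```python
# Symmetric (image-to-text + text-to-image) objective, as in inf_cl/ring.py:
loss_i2t = cal_inf_loss(q, k, scale=l.exp())
loss_t2i = cal_inf_loss(k, q, scale=l.exp())
loss = (loss_i2t + loss_t2i) * 0.5
```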

## 🚀 Main Results

### Memory Cost
<p><img src="https://github.com/user-attachments/assets/05dd3fea-0a93-4716-b321-0a94965e1fbe" width="800" /></p>

\* denotes adopting the "data offload" strategy.

### Max Supported Batch Size
<p><img src="https://github.com/user-attachments/assets/eb38fb90-3b7e-4696-b078-b7766893f758" width="800" /></p>

### Speed
<p><img src="https://github.com/user-attachments/assets/da72e99b-508b-450a-b12e-401d4991291a" width="800" /></p>

### Batch Size Scaling
<p><img src="https://github.com/user-attachments/assets/5b55fa98-6558-4509-9b66-e290ecf77b41" width="800" /></p>

Training at a larger data scale requires a larger batch size.

## 🗝️ Training & Evaluation

### Quick Start

To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate it on mainstream CLIP benchmarks.

1. Training Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
|   |   ├── 0000.tar
|   |   ├── 0001.tar
|   |   ├── ...
|   |   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
|   |   ├── 0000.tar
|   |   ├── 0001.tar
|   |   ├── ...
|   |   └── 1044.tar
│   ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
|   |   ├── 00000.tar
|   |   ├── 00001.tar
|   |   ├── ...
|   |   └── 41407.tar
```
2. Command:
```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```
3. Evaluation Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of imagenet from https://huggingface.co/datasets/ILSVRC/imagenet-1k/tree/main/data
|   |   └── val/ # python datasets/reformat_imagenet.py
|   |   |   ├── n01440764
|   |   |   ├── n01443537
|   |   |   ├── ...
|   |   |   └── n15075141
│   ├── clip-benchmark/ # bash datasets/benchmarks_download.sh
|   |   ├── wds_mscoco_captions
|   |   ├── wds_flickr8k
|   |   ├── wds_flickr30k
|   |   ├── wds_imagenet1k
|   |   ├── wds_imagenetv2
|   |   ├── wds_imagenet_sketch
|   |   ├── wds_imagenet-a
|   |   ├── wds_imagenet-r
|   |   ├── wds_imagenet-o
|   |   └── wds_objectnet
```
4. Command:
```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```

## 📑 Citation

If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:
```bibtex
@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}
```

## 👍 Acknowledgement
The codebase of Inf-CLIP is adapted from [**OpenCLIP**](https://github.com/mlfoundations/open_clip). We are also grateful to the following projects, from which our Inf-CL arose:
* [**OpenAI CLIP**](https://openai.com/index/clip/), [**img2dataset**](https://github.com/rom1504/img2dataset), [**CLIP-Benchmark**](https://github.com/LAION-AI/CLIP_benchmark).
* [**FlashAttention**](https://github.com/Dao-AILab/flash-attention), [**RingAttention**](https://github.com/haoliuhl/ringattention), [**RingFlashAttention**](https://github.com/zhuzilin/ring-flash-attention). 


## 🔒 License

This project is released under the Apache 2.0 license as found in the LICENSE file.
This project is a research preview intended for **non-commercial use ONLY**, subject to the model license of CLIP, the Terms of Use of the data generated by OpenAI, and the LAION terms of use. Please contact us if you find any potential violations.


================================================
FILE: inf_cl/__init__.py
================================================
from .flash import cal_flash_loss
from .ring import cal_ring_loss, cal_inf_loss

================================================
FILE: inf_cl/flash.py
================================================
import math

import torch
import torch.nn.functional as F
import numpy as np

import triton
import triton.language as tl


@triton.jit
def _prob_fwd_kernel(
    Q,
    K,
    LSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
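    # Blockwise log-sum-exp: iterate over K in BLOCK_N tiles, keeping a running
    # max (m_i) and running sum (lse_i) per query row (online softmax), so the
    # full seqlen_q x seqlen_k logits matrix is never materialized.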
    # start index of sequence length
    start_m = tl.program_id(0)

    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Initialize pointers to Q, K, V
    q_ptrs = Q + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    # initialize pointer to m and l
    lse_i    = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_i      = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    # loop over k, v and update accumulator
    end_n = seqlen_k
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))

        # Trying to combine the two masks seems to make the result wrong
        m_ij = tl.maximum(tl.max(qk, 1), m_i)
        p = tl.exp(qk - m_ij[:, None])
        # Fix out of bound access
        p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0)
        # -- update statistics
        lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1)
        m_i = m_ij

    lse_i = m_i + tl.log(lse_i)
    # mask out the padded values
    lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0)

    tl.store(LSE + offs_m, lse_i)


@triton.jit
def _dq_prob_bwd_kernel(
    Q,
    K,
    dQ,
    LSE,
    dLSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
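    # Backward w.r.t. Q: recompute the blockwise probabilities exp(qk - lse),
    # weight them by dLSE, and accumulate dq += (p * dlse) @ k tile by tile.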
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length
    start_m = tl.program_id(0)

    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Initialize pointers to Q, K, V
    q_ptrs  = Q  + ndims * offs_m[:, None]
    dq_ptrs = dQ + ndims * offs_m[:, None]
    k_ptrs  = K  + ndims * offs_n[:, None]
    # setting lse
    lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0)
    dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0)

    # loop over k, v and update accumulator
    end_n = seqlen_k        
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))

        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute q grad ----
            # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch
            # A solution for this problem
            # Refer to issue: https://github.com/triton-lang/triton/issues/4574
            # if allow_tf32:
            k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1)
            q_grad = tl.dot(qk_grad, k)
            # Another solution for this problem
            # Refer to https://github.com/triton-lang/triton/issues/376
            # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
            # -- store dq ----
            dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)


@triton.jit
def _dk_prob_bwd_kernel(
    Q,
    K,
    dK,
    LSE,
    dLSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
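    # Backward w.r.t. K: recompute the blockwise probabilities exp(qk - lse),
    # weight them by dLSE, and accumulate dk += (p * dlse)^T @ q tile by tile.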
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length
    start_n = tl.program_id(0)

    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Initialize pointers to Q, K, V
    q_ptrs  = Q  + ndims * offs_m[:, None]
    k_ptrs  = K  + ndims * offs_n[:, None]
    dk_ptrs = dK + ndims * offs_n[:, None]

    # loop over q and update accumulator
    end_m = seqlen_q        
    for start_m in range(0, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)

        # setting lse
        lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)
        dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))

        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute k grad ----
            q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1)
            k_grad = tl.dot(tl.trans(qk_grad), q)
            # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32))
            # -- store dk ----
            dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k)


def _flash_prob_forward(q, k):
    # shape constraints
    seqlen_q, nheads, d = q.shape
    seqlen_k, _, _ = k.shape
    assert k.shape == (seqlen_k, nheads, d)
    # assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype, "All tensors must have the same type"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda

    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((seqlen_q_rounded), device=q.device, dtype=torch.float32)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = 64
    BLOCK_N = 64
    num_warps = 8
    num_stages = 1
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _prob_fwd_kernel[grid](
        q,
        k,
        lse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    lse = lse[:seqlen_q]

    return lse


def _flash_prob_backward(q, k, lse, dlse):
    # shape constraints
    seqlen_q, nheads, d = q.shape
    seqlen_k, _, _ = k.shape
    assert k.shape == (seqlen_k, nheads, d)
    # assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype, "All tensors must have the same type"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda

    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.zeros_like(k, dtype=torch.float32)

    q = q.contiguous()
    k = k.contiguous()
    dlse = dlse.contiguous()

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = 64
    BLOCK_N = 64
    num_warps = 8
    num_stages = 1
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _dq_prob_bwd_kernel[grid](
        q,
        k,
        dq,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    BLOCK_M, BLOCK_N = BLOCK_N, BLOCK_M  # swap tile roles for the dK kernel (no-op while both are 64)
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), 1)
    _dk_prob_bwd_kernel[grid](
        q,
        k,
        dk,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    dq = dq[:seqlen_q]
    dk = dk[:seqlen_k]

    return dq, dk


class FlashProb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k):
        lse = _flash_prob_forward(q, k)
        ctx.save_for_backward(q, k, lse)

        return lse

    @staticmethod
    def backward(ctx, dlse):
        q, k, lse = ctx.saved_tensors
        dq, dk = _flash_prob_backward(q, k, lse, dlse)

        return dq, dk


def _cal_flash_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    # NOTE: logits forward or backward should keep fp32 for better precision
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = FlashProb.apply(q, k)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss


def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):
    if labels is None:
        labels = torch.arange(q.shape[0], device=q.device)
    if scale is None:
        scale = 1.0
    return _cal_flash_loss(scale * q, k, labels, head_dim)


if __name__ == '__main__':
    import time

    # Parameters
    num_heads = 3        # Number of attention heads
    seq_length_q = 32768 # Sequence length
    seq_length_k = 32768
    d_model = 256        # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q, num_heads * d_model), dtype=torch.float32, device="cuda") # Query
    k = torch.rand((seq_length_k, num_heads * d_model), dtype=torch.float32, device="cuda") # Key
    l = torch.ones([], device="cuda") * np.log(1 / 0.02); l.requires_grad = True

    q = F.normalize(q, p=2, dim=-1); q.requires_grad = True
    k = F.normalize(k, p=2, dim=-1); k.requires_grad = True

    q1 = q.clone().detach().requires_grad_(True)
    k1 = k.clone().detach().requires_grad_(True)
    l1 = l.clone().detach().requires_grad_(True)

    labels = torch.arange(seq_length_q).cuda()

    for i in range(1000):

        # A. torch gradient
        start = time.time()
        qk = torch.einsum("md,nd->mn", l.exp() * q, k)
        loss = F.cross_entropy(qk, labels, reduction="mean")
        loss.backward()
        end = time.time()

        # B. triton gradient
        start1 = time.time()
        loss1 = cal_flash_loss(q1, k1, labels, l1.exp())
        loss1 = loss1.mean()
        loss1.backward()
        end1 = time.time()

        print("========= Difference =========")
        print(end - start, end1 - start1, l.grad, l1.grad)
        print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None


================================================
FILE: inf_cl/ring.py
================================================
import os
import math
import random

import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
import numpy as np

import triton
import triton.language as tl

from .flash import _flash_prob_forward, _flash_prob_backward, _cal_flash_loss


class RingComm:
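    """Point-to-point ring communicator: every rank sends its block to rank + 1
    and receives the next block from rank - 1. The isend/irecv pairs are batched
    via batch_isend_irecv (commit) and overlapped with computation (wait)."""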

    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size
        # print(f'rank: {self.rank}, send_rank: {self.send_rank}, recv_rank: {self.recv_rank}')
        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(self, to_send, recv_tensor = None):
        if recv_tensor is None:
            res = torch.empty_like(to_send)
        else:
            res = recv_tensor

        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []


class GradientGather(torch.autograd.Function):
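    """Identity in the forward pass; all-reduces (sums) the incoming gradient
    across ranks in the backward pass. Applied to the shared logit scale."""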
    
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, dx):
        dist.all_reduce(dx)
        return dx


class RingProb(torch.autograd.Function):
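    """Distributed log-sum-exp of q @ k^T: key blocks circulate around the ring
    while each rank merges block-wise LSEs for its local query rows; the backward
    pass recomputes per-block gradients and circulates dk the same way."""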

    @staticmethod
    def forward(ctx, q, k, group):
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()

            # vanilla lse
            qk = torch.einsum("mhd,nhd->mn", q, k)
            block_lse = torch.log(torch.exp(qk).sum(dim=-1))

            if step == 0:
                lse = block_lse
            else:
                lse = lse - F.logsigmoid(lse - block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group
        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)
        d_k_comm = RingComm(ctx.group)
        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # vanilla gradient calculation
            qk = torch.einsum("mhd,nhd->mn", q, k)
            qk_grad = torch.exp(qk - lse[:, None]).float()
            qk_grad = qk_grad * dlse[:, None]
            block_dq_buffer = torch.einsum("mn,nhd->mhd", qk_grad, k.float())
            block_dk_buffer = torch.einsum("nm,mhd->nhd", qk_grad.T, q.float())

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()

        return dq, next_dk, None
    

class InfProb(torch.autograd.Function):
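    """Same ring schedule as RingProb, but each local block is handled by the
    Triton flash kernels (_flash_prob_forward / _flash_prob_backward), so the
    per-block logits matrix is never stored in global memory."""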

    @staticmethod
    def forward(ctx, q, k, group):
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()

            # flash lse
            block_lse = _flash_prob_forward(q, k)

            if step == 0:
                lse = block_lse
            else:
                lse = lse - F.logsigmoid(lse - block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group
        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)
        d_k_comm = RingComm(ctx.group)
        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # flash gradient calculation
            block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse)

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()

        return dq, next_dk, None


def set_seed(rank, seed=42):
    seed = rank + seed
    random.seed(seed)             
    torch.manual_seed(seed)      
    torch.cuda.manual_seed(seed)  
    torch.cuda.manual_seed_all(seed) 


def _cal_ring_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = RingProb.apply(q, k, None)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss


def _cal_inf_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = InfProb.apply(q, k, None)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss


def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):
    """The triton implementation of the ring-cl.

    Args:
        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When set to None, the labels are taken to be the range [0, B). Defaults to None.
        scale (torch.Tensor, optional): The logit scale applied to the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256.

    """

    if labels is None:
        labels = torch.arange(q.shape[0]).to(q.device)
    if scale is None:
        scale = 1.0
    else:
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_ring_loss(scale * q, k, labels, head_dim).mean()
    else:
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()


def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):
    """The triton implementation of the inf-cl.

    Args:
        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When set to None, the labels are taken to be the range [0, B). Defaults to None.
        scale (torch.Tensor, optional): The logit scale applied to the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256.

    """

    if labels is None:
        labels = torch.arange(q.shape[0]).to(q.device)
    if scale is None:
        scale = 1.0
    else:
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_inf_loss(scale * q, k, labels, head_dim).mean()
    else:
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()


if __name__ == "__main__":
    import time

    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    torch.cuda.set_device(f'cuda:{os.environ["LOCAL_RANK"]}')

    # Parameters
    dtype = torch.float32
    num_heads = 3        # Number of attention heads
    seq_length_q = 32768 # Sequence length
    seq_length_k = 32768
    d_model = 256        # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda")
    l = torch.ones([], dtype=dtype, device="cuda") * np.log(1 / 0.07); l = l.requires_grad_() # Logit scale

    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key

    q1 = q.clone().detach().requires_grad_()
    k1 = k.clone().detach().requires_grad_()
    l1 = l.clone().detach().requires_grad_()

    for i in range(1000):
        # A. local torch gradient
        start = time.time()
        # A.1. gather q, k
        gathered_q = [torch.zeros_like(q) for _ in range(world_size)]
        gathered_k = [torch.zeros_like(k) for _ in range(world_size)]
        dist.all_gather(gathered_q, q)
        dist.all_gather(gathered_k, k)
        gathered_q[rank] = q
        gathered_k[rank] = k
        all_q = torch.cat(gathered_q, dim=0)
        all_k = torch.cat(gathered_k, dim=0)
        # A.2. calculating qk logits
        qk = torch.einsum("md,nd->mn", l.exp() * all_q, all_k)
        kq = qk.T
        _labels = torch.arange(seq_length_q).to(q.device)
        # A.3. calculating loss
        loss_i2t = F.cross_entropy(qk, _labels, reduction="mean")
        loss_t2i = F.cross_entropy(kq, _labels, reduction="mean")
        # A.4. scaling loss to normal value
        scale_factor = (all_q.shape[0] / q.shape[0])
        loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor
        loss.backward()
        show_loss = loss.detach().clone()
        dist.all_reduce(show_loss)
        show_loss = show_loss / (world_size * scale_factor)
        end = time.time()

        dist.barrier()

        # B. triton implementation
        start1 = time.time()
        # labels = torch.arange(seq_length_q // world_size).to(q.device)
        loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp())
        loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp())
        loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5
        loss1.backward()
        end1 = time.time()

        dist.barrier()

        if rank == 0:
            print(rank, end - start, end1 - start1, loss, show_loss, loss1)
            print(l.grad, l1.grad, torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None


================================================
FILE: inf_clip/__init__.py
================================================
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint

from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained

from .models.tokenizer import SimpleTokenizer, tokenize, decode
from .models.transform import image_transform, AugmentationCfg
from .models.coca_arch import CoCa
from .models.clip_arch import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss

from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
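
A minimal tokenization sketch using the helpers re-exported above (illustrative; for ViT-B-32 the factory falls back to SimpleTokenizer with a 77-token context):

from inf_clip import get_tokenizer

tokenizer = get_tokenizer("ViT-B-32")
tokens = tokenizer(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # expected: torch.Size([2, 77])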


================================================
FILE: inf_clip/constants.py
================================================
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
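
These tuples are per-channel normalization statistics; a short sketch (assuming torchvision is installed) of how OPENAI_DATASET_MEAN / OPENAI_DATASET_STD are typically applied:

from torchvision import transforms
from inf_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),                                          # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),  # per-channel (x - mean) / std
])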


================================================
FILE: inf_clip/factory.py
================================================
import json
import logging
import os
import re
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
    list_pretrained_tags_by_model, download_pretrained_from_hf, convert_state_dict

from .models.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
from .models.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .models.clip_arch import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .models.coca_arch import CoCa
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, FlashClipLoss, RingClipLoss, InfClipLoss, DiscoClipLoss


HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry


def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None
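
# Example (illustrative, not part of the original file): extending and querying the config registry.
#
#   import inf_clip
#   inf_clip.add_model_config("/path/to/extra_configs")  # hypothetical directory of *.json configs
#   print(inf_clip.list_models())                        # registered names, including the extra configs
#   cfg = inf_clip.get_model_config("ViT-B-32")          # deep copy of the registered config dict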


def _get_hf_config(model_id, cache_dir=None):
    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return config


def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name)['model_cfg']
        except Exception:
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

    text_config = config.get('text_cfg', {})
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
):
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.pretrained import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None and force_patch_dropout != False:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None and force_image_size != False:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # NOTE: the pretrained_hf argument is intentionally ignored here; whether the HF text
            # tower loads its own pretrained weights is taken from the model config alone.
            # Original open_clip behaviour, kept for reference:
            #   model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
            model_cfg['text_cfg']['hf_model_pretrained'] = model_cfg['text_cfg']['hf_model_pretrained']
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        model_arch = model_cfg.pop("arch", "CLIP")
        if custom_text:
            if "CLIP" in model_arch:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
            elif "LiT" in model_arch:
                model = LiT(**model_cfg, cast_dtype=cast_dtype)
            elif "CoCa" in model_arch or "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from .models.transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model


def create_loss(args):
    if args.distill_model:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        return SigLipLoss(
            rank=args.rank,
            world_size=args.world_size,
        )
    elif args.flashloss:
        return FlashClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.ringloss:
        return RingClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod
        )
    elif args.infloss:
        return InfClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod
        )
    elif args.discoloss:
        return DiscoClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod
        )
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )
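
# Example (illustrative): create_loss expects an argparse-style namespace carrying the flags read
# above; a hypothetical minimal namespace selecting InfClipLoss could look like:
#
#   from argparse import Namespace
#   args = Namespace(
#       model="ViT-B-32", distill_model=None, siglip=False,
#       flashloss=False, ringloss=False, infloss=True, discoloss=False,
#       local_loss=False, gather_with_grad=False,
#       rank=0, world_size=1, horovod=False,
#   )
#   loss_fn = create_loss(args)  # -> InfClipLoss(rank=0, world_size=1, use_horovod=False)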


def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        **model_kwargs,
):
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        **model_kwargs,
    )

    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

    preprocess_train = image_transform_v2(
        pp_cfg,
        is_train=True,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform_v2(
        pp_cfg,
        is_train=False,
    )

    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        **model_kwargs,
):
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        cache_dir=cache_dir,
        require_pretrained=True,
        **model_kwargs,
    )

    if not return_transform:
        return model

    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )

    return model, preprocess
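
A rough end-to-end sketch of the factory API above (illustrative; builds a randomly initialized ViT-B-32 and assumes the model keeps open_clip's encode_image / encode_text interface):

import torch
from PIL import Image
from inf_clip import create_model_and_transforms, get_tokenizer

model, _, preprocess_val = create_model_and_transforms("ViT-B-32")  # no pretrained weights requested
tokenizer = get_tokenizer("ViT-B-32")

image = preprocess_val(Image.new("RGB", (256, 256))).unsqueeze(0)   # dummy image -> (1, 3, 224, 224)
texts = tokenizer(["a diagram", "a dog", "a cat"])

with torch.no_grad():
    image_features = torch.nn.functional.normalize(model.encode_image(image), dim=-1)
    text_features = torch.nn.functional.normalize(model.encode_text(texts), dim=-1)
    similarity = image_features @ text_features.t()                  # (1, 3) cosine similarities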


================================================
FILE: inf_clip/model_configs/EVA01-g-14-plus.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva_giant_patch14_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA01-g-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva_giant_patch14_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA02-B-16.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_base_patch16_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA02-E-14-plus.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_enormous_patch14_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 32
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA02-E-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_enormous_patch14_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA02-L-14-336.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "timm_model_name": "eva02_large_patch14_clip_336",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/EVA02-L-14.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "eva02_large_patch14_clip_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/LiT-B-16.json
================================================
{
    "arch": "LiT-B-16",
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_base_patch16_224",
        "timm_model_pretrained": true,
        "timm_pool": "token",
        "timm_proj": "linear"
    },
    "text_cfg": {
        "hf_tokenizer_name": "bert-base-uncased",
        "hf_model_name": "bert-base-uncased",
        "hf_model_pretrained": true,
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/LiT-B-32.json
================================================
{
    "arch": "LiT-B-32",
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_base_patch32_224",
        "timm_model_pretrained": true,
        "timm_pool": "token",
        "timm_proj": "linear"
    },
    "text_cfg": {
        "hf_tokenizer_name": "bert-base-uncased",
        "hf_model_name": "bert-base-uncased",
        "hf_model_pretrained": true,
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/LiT-L-16.json
================================================
{
    "arch": "LiT-L-16",
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_large_patch16_224",
        "timm_model_pretrained": true,
        "timm_pool": "token",
        "timm_proj": "linear"
    },
    "text_cfg": {
        "hf_tokenizer_name": "bert-large-uncased",
        "hf_model_name": "bert-large-uncased",
        "hf_model_pretrained": true,
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/MobileCLIP-B.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_base_mci_224",
        "timm_model_pretrained": false,
        "timm_pool": "token",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": false
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/MobileCLIP-S1.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "fastvit_mci1",
        "timm_model_pretrained": false,
        "timm_pool": "avg",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": true
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/MobileCLIP-S2.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "fastvit_mci2",
        "timm_model_pretrained": false,
        "timm_pool": "avg",
        "timm_proj": null,
        "timm_drop": 0.0,
        "timm_drop_path": 0.0,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "no_causal_mask": true
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/RN101-quickgelu.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            23,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/RN101.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            23,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/RN50-quickgelu.json
================================================
{
    "embed_dim": 1024,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}


================================================
FILE: inf_clip/model_configs/RN50.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/RN50x16.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 384,
        "layers": [
            6,
            8,
            18,
            8
        ],
        "width": 96,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/RN50x4.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 288,
        "layers": [
            4,
            6,
            10,
            6
        ],
        "width": 80,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/RN50x64.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 448,
        "layers": [
            3,
            15,
            36,
            10
        ],
        "width": 128,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-256.json
================================================
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 256,
        "timm_model_name": "vit_base_patch16_siglip_256",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-384.json
================================================
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 384,
        "timm_model_name": "vit_base_patch16_siglip_384",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-512.json
================================================
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 512,
        "timm_model_name": "vit_base_patch16_siglip_512",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json
================================================
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 256,
        "timm_model_name": "vit_base_patch16_siglip_256",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 250000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP.json
================================================
{
    "embed_dim": 768,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_base_patch16_siglip_224",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-plus-240.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 240,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-plus.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16-quickgelu.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-16.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-32-256.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}


================================================
FILE: inf_clip/model_configs/ViT-B-32-plus-256.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 896,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-32-quickgelu.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-14-378-quickgelu.json
================================================
{
    "embed_dim": 1024,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 378,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA-336.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 336,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 1024,
        "heads": 16,
        "layers": 24,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 1024,
        "heads": 16,
        "layers": 24,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-14-quickgelu.json
================================================
{
    "embed_dim": 1024,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/ViT-H-16.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14-280.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 280,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14-336.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA-336.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 768,
        "heads": 12,
        "layers": 12,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14-quickgelu.json
================================================
{
    "embed_dim": 768,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-14.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-16-320.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 320,
        "layers": 24,
        "width": 1024,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-256.json
================================================
{
    "embed_dim": 1024,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 256,
        "timm_model_name": "vit_large_patch16_siglip_256",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 1024,
        "heads": 16,
        "layers": 24,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-384.json
================================================
{
    "embed_dim": 1024,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 384,
        "timm_model_name": "vit_large_patch16_siglip_384",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 1024,
        "heads": 16,
        "layers": 24,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-L-16.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-M-16-alt.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16,
        "ls_init_value": 1e-4
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-M-16.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-M-32-alt.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-M-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-S-16-alt.json
================================================
{
    "embed_dim": 256,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 256,
        "heads": 4,
        "layers": 10
    }
}

================================================
FILE: inf_clip/model_configs/ViT-S-16.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-S-32-alt.json
================================================
{
    "embed_dim": 256,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 256,
        "heads": 4,
        "layers": 10
    }
}

================================================
FILE: inf_clip/model_configs/ViT-S-32.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json
================================================
{
    "embed_dim": 1152,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 384,
        "timm_model_name": "vit_so400m_patch14_siglip_384",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 64,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 1152,
        "heads": 16,
        "layers": 27,
        "mlp_ratio": 3.7362,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP.json
================================================
{
    "embed_dim": 1152,
    "init_logit_bias": -10,
    "custom_text": true,
    "vision_cfg": {
        "image_size": 224,
        "timm_model_name": "vit_so400m_patch14_siglip_224",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "context_length": 16,
        "vocab_size": 32000,
        "hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
        "tokenizer_kwargs": {
            "clean": "canonicalize"
        },
        "width": 1152,
        "heads": 16,
        "layers": 27,
        "mlp_ratio": 3.7362,
        "no_causal_mask": true,
        "proj_bias": true,
        "pool_type": "last",
        "norm_kwargs":{
            "eps": 1e-6
        }
    }
}

================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 336,
        "layers": 48,
        "width": 1664,
        "head_width": 104,
        "mlp_ratio": 4.9231,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 1280,
        "heads": 20,
        "layers": 32,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 224,
        "layers": 48,
        "width": 1664,
        "head_width": 104,
        "mlp_ratio": 4.9231,
        "patch_size": 14,
        "no_ln_pre": true,
        "pool_type": "avg",
        "final_ln_after_pool": true
    },
    "text_cfg": {
        "context_length": 32,
        "vocab_size": 32000,
        "hf_tokenizer_name": "bert-base-uncased",
        "tokenizer_kwargs": {
            "strip_sep_token": true
        },
        "width": 1280,
        "heads": 20,
        "layers": 32,
        "pool_type": "last",
        "no_causal_mask": true
    }
}

================================================
FILE: inf_clip/model_configs/ViT-bigG-14.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 224,
        "layers": 48,
        "width": 1664,
        "head_width": 104,
        "mlp_ratio": 4.9231,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 32
    }
}

================================================
FILE: inf_clip/model_configs/ViT-e-14.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 224,
        "layers": 56,
        "width": 1792,
        "head_width": 112,
        "mlp_ratio": 8.5715,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 36
    }
}

================================================
FILE: inf_clip/model_configs/ViT-g-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 40,
        "width": 1408,
        "head_width": 88,
        "mlp_ratio": 4.3637,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/ViTamin-B-LTT.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
      "timm_model_name": "vitamin_base_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 768,
      "heads": 12,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-B.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
      "timm_model_name": "vitamin_base_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 512,
      "heads": 8,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L-256.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
      "timm_model_name": "vitamin_large_256",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 256
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 768,
      "heads": 12,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L-336.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
      "timm_model_name": "vitamin_large_336",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 336
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 768,
      "heads": 12,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
      "timm_model_name": "vitamin_large_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 768,
      "heads": 12,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L2-256.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
      "timm_model_name": "vitamin_large2_256",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 256
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1024,
      "heads": 16,
      "layers": 24
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L2-336.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
      "timm_model_name": "vitamin_large2_336",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 336
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1024,
      "heads": 16,
      "layers": 24
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-L2.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
      "timm_model_name": "vitamin_large2_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1024,
      "heads": 16,
      "layers": 24
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-S-LTT.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
      "timm_model_name": "vitamin_small_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 768,
      "heads": 12,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-S.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
      "timm_model_name": "vitamin_small_224",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 224
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 384,
      "heads": 6,
      "layers": 12
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-XL-256.json
================================================
{
    "embed_dim": 1152,
    "vision_cfg": {
      "timm_model_name": "vitamin_xlarge_256",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 256
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1152,
      "heads": 16,
      "layers": 27
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-XL-336.json
================================================
{
    "embed_dim": 1152,
    "vision_cfg": {
      "timm_model_name": "vitamin_xlarge_336",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 336
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1152,
      "heads": 16,
      "layers": 27
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/ViTamin-XL-384.json
================================================
{
    "embed_dim": 1152,
    "vision_cfg": {
      "timm_model_name": "vitamin_xlarge_384",
      "timm_model_pretrained": false,
      "timm_pool": "",
      "timm_proj": "linear",
      "timm_drop": 0.0,
      "timm_drop_path": 0.1,
      "image_size": 256
    },
    "text_cfg": {
      "context_length": 77,
      "vocab_size": 49408,
      "width": 1152,
      "heads": 16,
      "layers": 27
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/coca_ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32,
        "attentional_pool": true,
        "attn_pooler_heads": 8,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "embed_cls": true,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "attn_pooler_heads": 8
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/coca_ViT-L-14.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "attentional_pool": true,
        "attn_pooler_heads": 8,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12,
        "embed_cls": true,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12,
        "attn_pooler_heads": 12
    },
    "custom_text": true
}


================================================
FILE: inf_clip/model_configs/coca_base.json
================================================
{
    "embed_dim": 512,
    "multimodal_cfg": {
        "width": 768,
        "context_length": 76,
        "vocab_size": 64000,
        "mlp_ratio": 4,
        "layers": 12,
        "dim_head": 64,
        "heads": 12,
        "n_queries": 256,
        "attn_pooler_heads": 8
    },
    "vision_cfg": {
        "image_size": 288,
        "layers": 12,
        "width": 768,
        "patch_size": 18,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 64000,
        "layers": 12,
        "heads": 12,
        "width": 768,
        "embed_cls": true,
        "output_tokens": true
    },
    "custom_text": true
}

================================================
FILE: inf_clip/model_configs/coca_roberta-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32,
        "output_tokens": true
    },
    "text_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "hf_proj_type": "linear",
        "width": 768,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "width": 768,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}


================================================
FILE: inf_clip/model_configs/convnext_base.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_base_w.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_base_w_320.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_large.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_large_d.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}

================================================
FILE: inf_clip/model_configs/convnext_large_d_320.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}

================================================
FILE: inf_clip/model_configs/convnext_small.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "convnext_small",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_tiny.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_tiny",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/convnext_xlarge.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 20
    }
}

================================================
FILE: inf_clip/model_configs/convnext_xxlarge.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/convnext_xxlarge_320.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: inf_clip/model_configs/mt5-base-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-base",
        "hf_tokenizer_name": "google/mt5-base",
        "hf_pooler_type": "mean_pooler"
    }
}


================================================
FILE: inf_clip/model_configs/mt5-xl-ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-xl",
        "hf_tokenizer_name": "google/mt5-xl",
        "hf_pooler_type": "mean_pooler"
    }
}


================================================
FILE: inf_clip/model_configs/nllb-clip-base-siglip.json
================================================
{
    "embed_dim": 768,
    "custom_text": true,
    "init_logit_bias": -10,
    "vision_cfg": {
        "image_size": 384,
        "timm_model_name": "vit_base_patch16_siglip_384",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "hf_model_name": "facebook/nllb-200-distilled-600M",
        "hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/nllb-clip-base.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "facebook/nllb-200-distilled-600M",
        "hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/nllb-clip-large-siglip.json
================================================
{
    "embed_dim": 1152,
    "custom_text": true,
    "init_logit_bias": -10,
    "vision_cfg": {
        "image_size": 384,
        "timm_model_name": "vit_so400m_patch14_siglip_384",
        "timm_model_pretrained": false,
        "timm_pool": "map",
        "timm_proj": "none"
    },
    "text_cfg": {
        "hf_model_name": "facebook/nllb-200-distilled-1.3B",
        "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/nllb-clip-large.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "facebook/nllb-200-distilled-1.3B",
        "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
        "hf_proj_type": "linear",
        "hf_pooler_type": "cls_pooler"
    }
}

================================================
FILE: inf_clip/model_configs/roberta-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "hf_pooler_type": "mean_pooler"
    }
}


================================================
FILE: inf_clip/model_configs/swin_base_patch4_window7_224.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "swin_base_patch4_window7_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/vit_medium_patch16_gap_256.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_medium_patch16_gap_256",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_relpos_medium_patch16_cls_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-base",
        "hf_tokenizer_name": "xlm-roberta-base",
        "hf_pooler_type": "mean_pooler"
    }
}


================================================
FILE: inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-large",
        "hf_tokenizer_name": "xlm-roberta-large",
        "hf_pooler_type": "mean_pooler"
    }
}


================================================
FILE: inf_clip/models/clip_arch.py
================================================
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
    text_global_pool
from ..utils import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'map', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # pooling strategy for the HF text tower, e.g. 'mean_pooler' or 'cls_pooler'
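
# NOTE (illustrative sketch, not part of the upstream model code): the JSON files under
# inf_clip/model_configs/ map directly onto the two dataclasses above. A minimal example of
# instantiating one of them by hand; in practice the factory module does this and also
# handles pretrained weights and preprocessing:
#
#   import json
#   with open("inf_clip/model_configs/ViT-bigG-14.json") as f:
#       cfg = json.load(f)
#   model = CLIP(
#       embed_dim=cfg["embed_dim"],
#       vision_cfg=cfg["vision_cfg"],  # dict is coerced to CLIPVisionCfg in _build_vision_tower
#       text_cfg=cfg["text_cfg"],      # dict is coerced to CLIPTextCfg in _build_text_tower
#   )
#
# Top-level keys such as "custom_text" (selects CustomTextCLIP instead of CLIP) and
# "init_logit_bias" are consumed by the factory / constructor rather than by these cfg classes.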


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def get_input_dtype(precision: str):
    input_dtype = None
    if precision in ('bf16', 'pure_bf16'):
        input_dtype = torch.bfloat16
    elif precision in ('fp16', 'pure_fp16'):
        input_dtype = torch.float16
    return input_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if vision_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
        if vision_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **vision_cfg.act_kwargs)

        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            attentional_pool=vision_cfg.attentional_pool,
            attn_pooler_queries=vision_cfg.attn_pooler_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            pos_embed_type=vision_cfg.pos_embed_type,
            no_ln_pre=vision_cfg.no_ln_pre,
            final_ln_after_pool=vision_cfg.final_ln_after_pool,
            pool_type=vision_cfg.pool_type,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual
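
# Dispatch summary (descriptive comment, not upstream code): _build_vision_tower picks the
# backbone from the config it is given:
#   * timm_model_name set (e.g. the ViTamin-* and SigLIP configs, which name models such as
#     "vitamin_large_224" or "vit_so400m_patch14_siglip_224") -> TimmModel wrapper;
#   * layers given as a tuple/list of stage depths -> ModifiedResNet (RN50-style models);
#   * otherwise (a plain integer layer count, e.g. ViT-bigG-14.json) -> VisionTransformer,
#     with attention heads derived as width // head_width.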


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        if text_cfg.norm_kwargs:
            norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
        if text_cfg.act_kwargs is not None:
            act_layer = partial(act_layer, **text_cfg.act_kwargs)

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            mlp_ratio=text_cfg.mlp_ratio,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            no_causal_mask=text_cfg.no_causal_mask,
            pad_id=text_cfg.pad_id,
            pool_type=text_cfg.pool_type,
            proj_bias=text_cfg.proj_bias,
            output_tokens=text_cfg.output_tokens,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text
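
# Dispatch summary (descriptive comment, not upstream code): _build_text_tower returns an
# HFTextEncoder whenever hf_model_name is set (the roberta/mt5/nllb configs, using
# hf_proj_type / hf_pooler_type for projection and pooling), and otherwise a TextTransformer
# built from context_length / vocab_size / width / heads / layers as in the stock configs.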


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'clip'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        x, _ = text_global_pool(x, text, self.text_pool_type)
        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection

        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
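
# Minimal usage sketch for CLIP (illustrative only; tokenization and image preprocessing
# come from the tokenizer/transform modules, not from this file):
#
#   images = torch.randn(4, 3, 224, 224)       # preprocessed image batch
#   texts = torch.randint(0, 49408, (4, 77))   # tokenized text batch
#   image_logits, text_logits = model.get_logits(images, texts)
#   # image_logits[i, j] = exp(logit_scale) * cos_sim(image_i, text_j), plus logit_bias
#   # when init_logit_bias is set (as in the SigLIP-style configs).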


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'clip'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(l, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            attr = getattr(l, "text_projection", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

        if isinstance(l, VisionTransformer):
            # convert vision nn.Parameter projections
            attr = getattr(l, "proj", None)
            if attr is not None:
                attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
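
# Usage note (descriptive): convert_weights_to_lp(model, dtype=torch.bfloat16) casts
# Linear/Conv/attention weights and the text/vision projection parameters to the target
# dtype while leaving the remaining parameters (notably LayerNorms and embeddings) in
# fp32, mirroring the partial-fp16 layout of the original OpenAI checkpoints.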


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)
    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
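
# Illustrative loading sketch (the checkpoint path and jit format here are assumptions; the
# actual download/caching logic lives in the factory/pretrained modules):
#
#   sd = torch.jit.load("ViT-B-32.pt", map_location="cpu").state_dict()
#   model = build_model_from_openai_state_dict(sd, cast_dtype=torch.float16)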


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    state_dict['visual.positional_embedding'] = new_pos_embed
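
# Worked example (descriptive): loading a patch-14 checkpoint trained at 224 px into a
# 336 px model changes the patch grid from 16x16 to 24x24, so the stored positional
# embedding of length 16*16 + 1 = 257 (grid + class token) is interpolated to
# 24*24 + 1 = 577 before load_state_dict is called.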


def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    old_pos_embed = state_dict.get('positional_embedding', None)
    if old_pos_embed is None:
        return
    # FIXME add support for text cls_token
    model_pos_embed = getattr(model, 'positional_embedding', None)
    if model_pos_embed is None:
        model_pos_embed = getattr(model.text, 'positional_embedding', None)

    old_num_pos = old_pos_embed.shape[0]
    old_width = old_pos_embed.shape[1]
    num_pos = model_pos_embed.shape[0]
    width = model_pos_embed.shape[1]
    assert old_width == width, 'text pos_embed width changed!'
    if old_num_pos == num_pos:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
    old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
    old_pos_embed = F.interpolate(
        old_pos_embed,
        size=num_pos,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
    new_pos_embed = old_pos_embed

    state_dict['positional_embedding'] = new_pos_embed


def get_model_preprocess_cfg(model):
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # use separate legacy attributes if preprocess_cfg dict not found
        size = getattr(module, 'image_size', None)
        if size is not None:
            preprocess_cfg['size'] = size
        mean = getattr(module, 'image_mean', None)
        if mean is not None:
            preprocess_cfg['mean'] = mean
        std = getattr(module, 'image_std', None)
        if std is not None:
            preprocess_cfg['std'] = std
    return preprocess_cfg


def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    module = getattr(model, 'visual', model)
    module.image_mean = preprocess_cfg['mean']  # legacy attribute, keeping for bwd compat
    module.image_std = preprocess_cfg['std']  # legacy attribute, keeping for bwd compat
    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict


def get_model_tokenize_cfg(model):
    module = getattr(model, 'text', model)
    cfg = {}
    context_length = getattr(module, 'context_length', None)
    if context_length is not None:
        cfg['context_length'] = context_length
    vocab_size = getattr(module, 'vocab_size', None)
    if vocab_size is not None:
        cfg['vocab_size'] = vocab_size
    return cfg
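
# Example outputs (descriptive): for the stock CLIP text tower this returns
# {'context_length': 77, 'vocab_size': 49408}; for ViT-SO400M-14-SigLIP it returns
# {'context_length': 16, 'vocab_size': 32000}.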

================================================
FILE: inf_clip/models/coca_arch.py
================================================
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .clip_arch import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )

    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
    if not isinstance(token_id, torch.Tensor):
        if isinstance(token_id, int):
            token_id = [token_id]
        token_id = torch.tensor(token_id, device=device)
    return token_id


class CoCa(nn.Module):
    arch_type: torch.jit.Final[str] = 'coca'
    
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict
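
    # Shape note (descriptive, not upstream code): with output_labels enabled, "logits" is
    # [batch, seq_len - 1, vocab_size] and "labels" is [batch, seq_len - 1], so a captioning
    # loss can be computed along the lines of
    # F.cross_entropy(out["logits"].permute(0, 2, 1), out["labels"], ignore_index=pad_id)
    # alongside the usual contrastive loss on image_features / text_features.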

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False # if True output.shape == (batch_size, seq_len)
    ):
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't all produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
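
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original repository): a minimal,
# hedged illustration of the two generation paths exposed by CoCa.generate
# above. `model` is assumed to be an already-constructed (ideally pretrained)
# CoCa instance and `image` a preprocessed (1, 3, H, W) float tensor; both
# names are placeholders, and the keyword values simply mirror the defaults.
def _example_generate_captions(model, image):
    model.eval()
    with torch.no_grad():
        # default path: grouped beam search via _generate_beamsearch
        beam_ids = model.generate(
            image, generation_type="beam_search", num_beams=6, num_beam_groups=3, seq_len=30,
        )
        # sampling path: nucleus (top-p) filtering through GENERATION_TYPES
        topp_ids = model.generate(
            image, generation_type="top_p", top_p=0.1, temperature=1.0, seq_len=30,
        )
    return beam_ids, topp_ids  # token-id tensors, decodable with the tokenizer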


================================================
FILE: inf_clip/models/hf_configs.py
================================================
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
        },
        "pooler": "cls_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/m2m_100
    "m2m_100": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "encoder_attention_heads",
            "layers": "encoder_layers",
        },
        "pooler": "cls_pooler",
    },
}
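
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original repository): arch_dict maps
# an HF `model_type` to the attribute names that HFTextEncoder reads off the
# corresponding AutoConfig. A hedged illustration, assuming `transformers` is
# installed and the "roberta-base" config is downloadable or cached.
def _example_config_lookup():
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("roberta-base")
    names = arch_dict[config.model_type]["config_names"]
    width = getattr(config, names["width"])    # hidden_size -> 768
    layers = getattr(config, names["layers"])  # num_hidden_layers -> 12
    return width, layers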


================================================
FILE: inf_clip/models/hf_model.py
================================================
""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
from contextlib import nullcontext

import torch
import torch.nn as nn
from torch import TensorType

try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
    from transformers.modeling_utils import no_init_weights
except ImportError as e:
    transformers = None


    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# TODO: ?last - for gpt-like models
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # mask padded positions (attention_mask == 0) with -inf so they never win the max
        masked_output = x.last_hidden_state.masked_fill((attention_mask == 0).unsqueeze(-1), -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
            (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]


@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        return x.last_hidden_state[:, self.cls_token_position, :]
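
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original repository): register_pooler
# keys each pooler by the snake_case of its class name, so a config can pick
# one by string. A hedged, self-contained check of the mean pooler, using a
# SimpleNamespace as a stand-in for a transformers BaseModelOutput.
def _example_mean_pooling():
    from types import SimpleNamespace

    hidden = torch.randn(2, 4, 8)                       # (batch, seq, width)
    mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # 1 = real token, 0 = padding
    pooler = _POOLERS["mean_pooler"]()                  # registry lookup -> MeanPooler()
    return pooler(SimpleNamespace(last_hidden_state=hidden), mask)  # shape (2, 8)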


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            pooler_type: str = None,
            proj_type: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")

        self.config = AutoConfig.from_pretrained(model_name_or_path)
        # FIXME: Gradient Accumulation can't fully resume the dropout state, so we close dropout here.
        # self.config.attention_probs_dropout_prob = 0.0  # Disable dropout
        # self.config.hidden_dropout_prob = 0.0  # Disable dropout
        # Enable sdpa attention
        self.config._attn_implementation = 'sdpa'
        # initialization of the model is really slow (https://github.com/huggingface/transformers/issues/9205#issuecomment-748741195)
        # FIXME: To speed up the initialization of the model, we only load pretrained weights and 
        # disable the torch initialization of the weights.
        if pretrained:
            context = no_init_weights
        else:
            context = nullcontext

        with context():
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True, add_pooling_layer=uses_transformer_pooler)

        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])

        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] 
            if type(self.pooler) == ClsPooler 
            else out.last_hidden_state
        )
        
        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
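
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original repository): a hedged
# example of building a RoBERTa text tower with a linear projection to a
# 512-d embedding space. Requires `transformers` plus downloadable or cached
# "roberta-base" weights; the token ids below are toy values.
def _example_text_tower():
    encoder = HFTextEncoder(
        "roberta-base",
        output_dim=512,
        proj_type="linear",
        pretrained=True,
    )
    encoder.lock(unlocked_layers=2)               # freeze all but the last two blocks
    token_ids = torch.randint(5, 1000, (2, 16))   # RoBERTa pad_token_id is 1, so all ids count as real tokens
    return encoder(token_ids)                     # (2, 512) projected features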


================================================
FILE: inf_clip/models/lit_arch.py
================================================
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as torch_checkpoint

from .clip_arch import _build_vision_tower, _build_text_tower


@dataclass
class LiTVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth


@dataclass
class LiTTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # pooler type for HF text models


class LiT(nn.Module):
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'lit'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: LiTVisionCfg,
            text_cfg: LiTTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

        self.embed_dim = embed_dim
        self.lock_image_tower()

    def get_embed_dim(self):
        return self.embed_dim

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_trunk_image(self, image, normalize: bool = False):
        trunk_features = self.visual.forward_trunk(image)
        features = self.visual.forward_head(trunk_features)
        return trunk_features, F.normalize(features, dim=-1) if normalize else features

    def project_image(self, trunk_features, normalize: bool = False):
        features = self.visual.head(trunk_features)
        return trunk_features, F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            project_only: Optional[bool] = False,
    ):
        if image is None:
            image_trunk_features, image_features = None, None
        elif project_only:
            image_trunk_features, image_features = self.project_image(image, normalize=True)
        else:
            image_trunk_features, image_features = self.encode_trunk_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_trunk_features": image_trunk_features,
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_trunk_features, image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_trunk_features, image_features, text_features, self.logit_scale.exp()
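
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original repository): a hedged
# illustration of the LiT inference path. `model`, `images`, and `texts` are
# placeholders for an already-built LiT instance, a preprocessed image batch,
# and a tokenized text batch; get_logits applies the learned temperature and
# optional bias, which the step-by-step version below reproduces.
def _example_lit_logits(model, images, texts):
    image_logits, text_logits = model.get_logits(images, texts)

    # equivalent, step by step
    img = model.encode_image(images, normalize=True)   # frozen image tower
    txt = model.encode_text(texts, normalize=True)     # trainable text tower
    manual = model.logit_scale.exp() * img @ txt.T
    if model.logit_bias is not None:
        manual = manual + model.logit_bias
    return image_logits, text_logits, manual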


================================================
FILE: inf_clip/models/loss.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

from inf_cl import cal_flash_loss, cal_ring_loss, cal_inf_loss


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    if world_size == 1:
        return image_features, text_features
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


def all_reduce(tensor):
    if not dist.is_available() or not dist.is_initialized():
        # non-distributed run: nothing to reduce
        return tensor
    dist.all_reduce(tensor)
    return tensor


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
        
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels, reduction='none') +
            F.cross_entropy(logits_per_text, labels, reduction='none')
        ) / 2

        scale_factor = (total_loss.shape[0] / image_features.shape[0])
        total_loss = torch.mean(total_loss * scale_factor)

        show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)

        return {"contrastive_loss": total_loss, "show_loss": show_loss}


class DiscoClipLoss(nn.Module):

    def __init__(
            self,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculated ground-truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            la
SYMBOL INDEX (447 symbols across 25 files)

FILE: inf_cl/flash.py
  function _prob_fwd_kernel (line 12) | def _prob_fwd_kernel(
  function _dq_prob_bwd_kernel (line 70) | def _dq_prob_bwd_kernel(
  function _dk_prob_bwd_kernel (line 140) | def _dk_prob_bwd_kernel(
  function _flash_prob_forward (line 204) | def _flash_prob_forward(q, k):
  function _flash_prob_backward (line 242) | def _flash_prob_backward(q, k, lse, dlse):
  class FlashProb (line 306) | class FlashProb(torch.autograd.Function):
    method forward (line 309) | def forward(ctx, q, k):
    method backward (line 316) | def backward(ctx, dlse):
  function _cal_flash_loss (line 323) | def _cal_flash_loss(q, k, labels, head_dim=256):
  function cal_flash_loss (line 337) | def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):

FILE: inf_cl/ring.py
  class RingComm (line 17) | class RingComm:
    method __init__ (line 19) | def __init__(self, process_group: dist.ProcessGroup):
    method send_recv (line 33) | def send_recv(self, to_send, recv_tensor = None):
    method commit (line 45) | def commit(self):
    method wait (line 50) | def wait(self):
  class GradientGather (line 59) | class GradientGather(torch.autograd.Function):
    method forward (line 62) | def forward(ctx, x):
    method backward (line 67) | def backward(ctx, dx):
  class RingProb (line 72) | class RingProb(torch.autograd.Function):
    method forward (line 75) | def forward(ctx, q, k, group):
    method backward (line 109) | def backward(ctx, dlse):
  class InfProb (line 154) | class InfProb(torch.autograd.Function):
    method forward (line 157) | def forward(ctx, q, k, group):
    method backward (line 190) | def backward(ctx, dlse):
  function set_seed (line 231) | def set_seed(rank, seed=42):
  function _cal_ring_loss (line 239) | def _cal_ring_loss(q, k, labels, head_dim=256):
  function _cal_inf_loss (line 252) | def _cal_inf_loss(q, k, labels, head_dim=256):
  function cal_ring_loss (line 265) | def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):
  function cal_inf_loss (line 289) | def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):

FILE: inf_clip/factory.py
  function _natural_key (line 32) | def _natural_key(string_):
  function _rescan_model_configs (line 36) | def _rescan_model_configs():
  function list_models (line 60) | def list_models():
  function add_model_config (line 65) | def add_model_config(path):
  function get_model_config (line 73) | def get_model_config(model_name):
  function _get_hf_config (line 80) | def _get_hf_config(model_id, cache_dir=None):
  function get_tokenizer (line 87) | def get_tokenizer(
  function load_state_dict (line 131) | def load_state_dict(checkpoint_path: str, map_location='cpu'):
  function load_checkpoint (line 146) | def load_checkpoint(
  function create_model (line 183) | def create_model(
  function create_loss (line 349) | def create_loss(args):
  function create_model_and_transforms (line 410) | def create_model_and_transforms(
  function create_model_from_pretrained (line 467) | def create_model_from_pretrained(

FILE: inf_clip/models/clip_arch.py
  class CLIPVisionCfg (line 27) | class CLIPVisionCfg:
  class CLIPTextCfg (line 58) | class CLIPTextCfg:
  function get_cast_dtype (line 86) | def get_cast_dtype(precision: str):
  function get_input_dtype (line 95) | def get_input_dtype(precision: str):
  function _build_vision_tower (line 104) | def _build_vision_tower(
  function _build_text_tower (line 173) | def _build_text_tower(
  class CLIP (line 220) | class CLIP(nn.Module):
    method __init__ (line 224) | def __init__(
    method lock_image_tower (line 257) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 262) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 266) | def encode_image(self, image, normalize: bool = False):
    method encode_text (line 270) | def encode_text(self, text, normalize: bool = False):
    method get_logits (line 287) | def get_logits(self, image, text):
    method forward (line 296) | def forward(
  class CustomTextCLIP (line 319) | class CustomTextCLIP(nn.Module):
    method __init__ (line 323) | def __init__(
    method lock_image_tower (line 346) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method lock_text_tower (line 350) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
    method set_grad_checkpointing (line 354) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 358) | def encode_image(self, image, normalize: bool = False):
    method encode_text (line 362) | def encode_text(self, text, normalize: bool = False):
    method get_logits (line 366) | def get_logits(self, image, text):
    method forward (line 375) | def forward(
  function convert_weights_to_lp (line 398) | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
  function convert_to_custom_text_state_dict (line 432) | def convert_to_custom_text_state_dict(state_dict: dict):
  function build_model_from_openai_state_dict (line 450) | def build_model_from_openai_state_dict(
  function trace_model (line 509) | def trace_model(model, batch_size=256, device=torch.device('cpu')):
  function resize_pos_embed (line 525) | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', ...
  function resize_text_pos_embed (line 559) | def resize_text_pos_embed(state_dict, model, interpolation: str = 'linea...
  function get_model_preprocess_cfg (line 591) | def get_model_preprocess_cfg(model):
  function set_model_preprocess_cfg (line 608) | def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
  function get_model_tokenize_cfg (line 615) | def get_model_tokenize_cfg(model):

FILE: inf_clip/models/coca_arch.py
  class MultimodalCfg (line 47) | class MultimodalCfg(CLIPTextCfg):
  function _build_text_decoder_tower (line 55) | def _build_text_decoder_tower(
  function _token_to_tensor (line 81) | def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
  class CoCa (line 89) | class CoCa(nn.Module):
    method __init__ (line 92) | def __init__(
    method set_grad_checkpointing (line 146) | def set_grad_checkpointing(self, enable: bool = True):
    method _encode_image (line 151) | def _encode_image(self, images, normalize: bool = True):
    method _encode_text (line 156) | def _encode_text(self, text, normalize: bool = True):
    method encode_image (line 161) | def encode_image(self, images, normalize: bool = True):
    method encode_text (line 165) | def encode_text(self, text, normalize: bool = True):
    method forward (line 169) | def forward(
    method generate (line 204) | def generate(
    method _generate_beamsearch (line 331) | def _generate_beamsearch(
  function prepare_inputs_for_generation (line 481) | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **...

FILE: inf_clip/models/hf_model.py
  class BaseModelOutput (line 22) | class BaseModelOutput:
  class PretrainedConfig (line 26) | class PretrainedConfig:
  function _camel2snake (line 33) | def _camel2snake(s):
  function register_pooler (line 41) | def register_pooler(cls):
  class MeanPooler (line 48) | class MeanPooler(nn.Module):
    method forward (line 51) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class MaxPooler (line 57) | class MaxPooler(nn.Module):
    method forward (line 60) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class ClsPooler (line 66) | class ClsPooler(nn.Module):
    method __init__ (line 69) | def __init__(self, use_pooler_output=True):
    method forward (line 74) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class ClsLastHiddenStatePooler (line 85) | class ClsLastHiddenStatePooler(nn.Module):
    method __init__ (line 90) | def __init__(self):
    method forward (line 94) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class HFTextEncoder (line 98) | class HFTextEncoder(nn.Module):
    method __init__ (line 102) | def __init__(
    method forward (line 166) | def forward(self, x: TensorType):
    method lock (line 183) | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
    method set_grad_checkpointing (line 201) | def set_grad_checkpointing(self, enable=True):
    method init_parameters (line 204) | def init_parameters(self):

FILE: inf_clip/models/lit_arch.py
  class LiTVisionCfg (line 14) | class LiTVisionCfg:
  class LiTTextCfg (line 45) | class LiTTextCfg:
  class LiT (line 73) | class LiT(nn.Module):
    method __init__ (line 77) | def __init__(
    method get_embed_dim (line 103) | def get_embed_dim(self):
    method lock_image_tower (line 106) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method lock_text_tower (line 110) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
    method set_grad_checkpointing (line 114) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 118) | def encode_image(self, image, normalize: bool = False):
    method encode_trunk_image (line 122) | def encode_trunk_image(self, image, normalize: bool = False):
    method project_image (line 127) | def project_image(self, trunk_features, normalize: bool = False):
    method encode_text (line 131) | def encode_text(self, text, normalize: bool = False):
    method get_logits (line 135) | def get_logits(self, image, text):
    method forward (line 144) | def forward(

FILE: inf_clip/models/loss.py
  function gather_features (line 21) | def gather_features(
  function all_reduce (line 70) | def all_reduce(tensor):
  class ClipLoss (line 79) | class ClipLoss(nn.Module):
    method __init__ (line 81) | def __init__(
    method get_ground_truth (line 102) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
    method get_logits (line 115) | def get_logits(self, image_features, text_features, logit_scale):
    method forward (line 133) | def forward(self, image_features, text_features, logit_scale):
  class DiscoClipLoss (line 152) | class DiscoClipLoss(nn.Module):
    method __init__ (line 154) | def __init__(
    method get_ground_truth (line 169) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
    method forward (line 180) | def forward(self, image_features, text_features, logit_scale):
  class FlashClipLoss (line 204) | class FlashClipLoss(nn.Module):
    method __init__ (line 206) | def __init__(
    method get_ground_truth (line 221) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
    method forward (line 231) | def forward(self, image_features, text_features, logit_scale):
  class RingClipLoss (line 251) | class RingClipLoss(nn.Module):
    method __init__ (line 253) | def __init__(
    method forward (line 267) | def forward(self, image_features, text_features, logit_scale):
  class InfClipLoss (line 291) | class InfClipLoss(nn.Module):
    method __init__ (line 293) | def __init__(
    method forward (line 307) | def forward(self, image_features, text_features, logit_scale):
  class CoCaLoss (line 384) | class CoCaLoss(ClipLoss):
    method __init__ (line 385) | def __init__(
    method forward (line 410) | def forward(self, image_features, text_features, logits, labels, logit...
  class DistillClipLoss (line 430) | class DistillClipLoss(ClipLoss):
    method dist_loss (line 432) | def dist_loss(self, teacher_logits, student_logits):
    method forward (line 435) | def forward(
  function neighbour_exchange (line 469) | def neighbour_exchange(from_rank, to_rank, tensor, group=None):
  function neighbour_exchange_bidir (line 489) | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tens...
  class NeighbourExchange (line 522) | class NeighbourExchange(torch.autograd.Function):
    method forward (line 524) | def forward(ctx, from_rank, to_rank, group, tensor):
    method backward (line 531) | def backward(ctx, grad_output):
  function neighbour_exchange_with_grad (line 535) | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
  class NeighbourExchangeBidir (line 539) | class NeighbourExchangeBidir(torch.autograd.Function):
    method forward (line 541) | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_...
    method backward (line 548) | def backward(ctx, *grad_outputs):
  function neighbour_exchange_bidir_with_grad (line 553) | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_...
  class SigLipLoss (line 557) | class SigLipLoss(nn.Module):
    method __init__ (line 567) | def __init__(
    method get_ground_truth (line 587) | def get_ground_truth(self, device, dtype, num_logits, negative_only=Fa...
    method get_logits (line 593) | def get_logits(self, image_features, text_features, logit_scale, logit...
    method _loss (line 599) | def _loss(self, image_features, text_features, logit_scale, logit_bias...
    method forward (line 610) | def forward(self, image_features, text_features, logit_scale, logit_bi...

FILE: inf_clip/models/modified_resnet.py
  class Bottleneck (line 10) | class Bottleneck(nn.Module):
    method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 42) | def forward(self, x: torch.Tensor):
  class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module):
    method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
    method forward (line 68) | def forward(self, x):
  class ModifiedResNet (line 95) | class ModifiedResNet(nn.Module):
    method __init__ (line 103) | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
    method _make_layer (line 132) | def _make_layer(self, planes, blocks, stride=1):
    method init_parameters (line 141) | def init_parameters(self):
    method lock (line 154) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 162) | def set_grad_checkpointing(self, enable=True):
    method stem (line 166) | def stem(self, x):
    method forward (line 173) | def forward(self, x):

FILE: inf_clip/models/pos_embed.py
  function get_2d_sincos_pos_embed (line 20) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
  function get_2d_sincos_pos_embed_from_grid (line 38) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  function interpolate_pos_embed (line 75) | def interpolate_pos_embed(model, checkpoint_model):

FILE: inf_clip/models/timm_model.py
  class TimmModel (line 28) | class TimmModel(nn.Module):
    method __init__ (line 32) | def __init__(
    method lock (line 110) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 143) | def set_grad_checkpointing(self, enable=True):
    method forward_trunk (line 149) | def forward_trunk(self, x):
    method forward_head (line 152) | def forward_head(self, x):
    method forward (line 155) | def forward(self, x):

FILE: inf_clip/models/tokenizer.py
  function default_bpe (line 27) | def default_bpe():
  function bytes_to_unicode (line 32) | def bytes_to_unicode():
  function get_pairs (line 54) | def get_pairs(word):
  function basic_clean (line 66) | def basic_clean(text):
  function whitespace_clean (line 72) | def whitespace_clean(text):
  function _clean_canonicalize (line 78) | def _clean_canonicalize(x):
  function _clean_lower (line 83) | def _clean_lower(x):
  function _clean_whitespace (line 88) | def _clean_whitespace(x):
  function get_clean_fn (line 93) | def get_clean_fn(type: str):
  function canonicalize_text (line 104) | def canonicalize_text(
  class SimpleTokenizer (line 133) | class SimpleTokenizer(object):
    method __init__ (line 134) | def __init__(
    method bpe (line 172) | def bpe(self, token):
    method encode (line 213) | def encode(self, text):
    method decode (line 221) | def decode(self, tokens):
    method __call__ (line 226) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...
  function decode (line 271) | def decode(output_ids: torch.Tensor):
  function tokenize (line 276) | def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT...
  function random_mask_tokenize (line 280) | def random_mask_tokenize(
  function simple_mask_tokenize (line 309) | def simple_mask_tokenize(
  function syntax_mask_tokenize (line 331) | def syntax_mask_tokenize(
  function get_reduction_mask_fn (line 390) | def get_reduction_mask_fn(type: str):
  class HFTokenizer (line 403) | class HFTokenizer:
    method __init__ (line 406) | def __init__(
    method save_pretrained (line 426) | def save_pretrained(self, dest):
    method __call__ (line 429) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...
    method set_language (line 456) | def set_language(self, src_lang):
  class SigLipTokenizer (line 463) | class SigLipTokenizer:
    method __init__ (line 473) | def __init__(
    method save_pretrained (line 497) | def save_pretrained(self, dest):
    method __call__ (line 500) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...

FILE: inf_clip/models/transform.py
  class PreprocessCfg (line 17) | class PreprocessCfg:
    method __post_init__ (line 26) | def __post_init__(self):
    method num_channels (line 30) | def num_channels(self):
    method input_size (line 34) | def input_size(self):
  function merge_preprocess_dict (line 40) | def merge_preprocess_dict(
  function merge_preprocess_kwargs (line 57) | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
  class AugmentationCfg (line 62) | class AugmentationCfg:
  function _setup_size (line 75) | def _setup_size(size, error_msg):
  class ResizeKeepRatio (line 88) | class ResizeKeepRatio:
    method __init__ (line 94) | def __init__(
    method get_params (line 116) | def get_params(
    method __call__ (line 144) | def __call__(self, img):
    method __repr__ (line 160) | def __repr__(self):
  function center_crop_or_pad (line 167) | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0...
  class CenterCropOrPad (line 207) | class CenterCropOrPad(torch.nn.Module):
    method __init__ (line 219) | def __init__(self, size, fill=0):
    method forward (line 224) | def forward(self, img):
    method __repr__ (line 234) | def __repr__(self) -> str:
  function _convert_to_rgb (line 238) | def _convert_to_rgb(image):
  class color_jitter (line 242) | class color_jitter(object):
    method __init__ (line 246) | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., ...
    method __call__ (line 251) | def __call__(self, img):
  class gray_scale (line 258) | class gray_scale(object):
    method __init__ (line 262) | def __init__(self, p=0.2):
    method __call__ (line 267) | def __call__(self, img):
  function image_transform (line 274) | def image_transform(
  function image_transform_v2 (line 393) | def image_transform_v2(
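
For the transform.py entries, a typical use is building the evaluation-time preprocessing pipeline. This is a rough sketch only: it assumes image_transform keeps OpenCLIP-style arguments (image_size, is_train, mean, std) and returns a torchvision transform; the import path is inferred from the file layout, while OPENAI_DATASET_MEAN/STD are confirmed by the inf_clip/constants.py preview below.

# Hedged sketch of building an eval-time preprocessing pipeline.
from PIL import Image

from inf_clip import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD   # re-exported from constants.py
from inf_clip.models.transform import image_transform           # path inferred, not confirmed

preprocess_val = image_transform(
    224,                          # target image size (square, assumed)
    is_train=False,               # deterministic resize/center-crop path
    mean=OPENAI_DATASET_MEAN,
    std=OPENAI_DATASET_STD,
)

img = Image.new("RGB", (640, 480))   # stand-in image
pixel_values = preprocess_val(img)   # expected shape: [3, 224, 224]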

FILE: inf_clip/models/transformer.py
  class LayerNormFp32 (line 15) | class LayerNormFp32(nn.LayerNorm):
    method forward (line 18) | def forward(self, x: torch.Tensor):
  class LayerNorm (line 24) | class LayerNorm(nn.LayerNorm):
    method forward (line 27) | def forward(self, x: torch.Tensor):
  class QuickGELU (line 33) | class QuickGELU(nn.Module):
    method forward (line 35) | def forward(self, x: torch.Tensor):
  class LayerScale (line 39) | class LayerScale(nn.Module):
    method __init__ (line 40) | def __init__(self, dim, init_values=1e-5, inplace=False):
    method forward (line 45) | def forward(self, x):
  class PatchDropout (line 49) | class PatchDropout(nn.Module):
    method __init__ (line 54) | def __init__(self, prob, exclude_first_token=True):
    method forward (line 60) | def forward(self, x):
  class Attention (line 89) | class Attention(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 132) | def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  class AttentionalPooler (line 187) | class AttentionalPooler(nn.Module):
    method __init__ (line 188) | def __init__(
    method forward (line 202) | def forward(self, x: torch.Tensor):
  class ResidualAttentionBlock (line 210) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 211) | def __init__(
    method attention (line 239) | def attention(
    method forward (line 254) | def forward(
  class CustomResidualAttentionBlock (line 268) | class CustomResidualAttentionBlock(nn.Module):
    method __init__ (line 269) | def __init__(
    method get_reference_weight (line 306) | def get_reference_weight(self):
    method forward (line 309) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  function _expand_token (line 315) | def _expand_token(token, batch_size: int):
  class Transformer (line 319) | class Transformer(nn.Module):
    method __init__ (line 320) | def __init__(
    method get_cast_dtype (line 350) | def get_cast_dtype(self) -> torch.dtype:
    method forward (line 355) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  class CustomTransformer (line 369) | class CustomTransformer(nn.Module):
    method __init__ (line 371) | def __init__(
    method get_cast_dtype (line 412) | def get_cast_dtype(self) -> torch.dtype:
    method forward (line 418) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  class VisionTransformer (line 434) | class VisionTransformer(nn.Module):
    method __init__ (line 437) | def __init__(
    method lock (line 541) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method init_parameters (line 574) | def init_parameters(self):
    method set_grad_checkpointing (line 595) | def set_grad_checkpointing(self, enable=True):
    method _global_pool (line 598) | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.T...
    method forward (line 608) | def forward(self, x: torch.Tensor):
  function text_global_pool (line 653) | def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: ...
  class TextTransformer (line 668) | class TextTransformer(nn.Module):
    method __init__ (line 671) | def __init__(
    method init_parameters (line 731) | def init_parameters(self):
    method set_grad_checkpointing (line 755) | def set_grad_checkpointing(self, enable=True):
    method build_causal_mask (line 758) | def build_causal_mask(self):
    method build_cls_mask (line 766) | def build_cls_mask(self, text, cast_dtype: torch.dtype):
    method forward (line 775) | def forward(self, text):
  class MultimodalTransformer (line 812) | class MultimodalTransformer(Transformer):
    method __init__ (line 813) | def __init__(
    method init_parameters (line 856) | def init_parameters(self):
    method build_attention_mask (line 874) | def build_attention_mask(self):
    method forward (line 882) | def forward(self, image_embs, text_embs):
    method set_grad_checkpointing (line 907) | def set_grad_checkpointing(self, enable=True):

FILE: inf_clip/openai.py
  function list_openai_models (line 20) | def list_openai_models() -> List[str]:
  function load_openai_model (line 25) | def load_openai_model(

FILE: inf_clip/pretrained.py
  function _pcfg (line 34) | def _pcfg(url='', hf_hub='', **kwargs):
  function _slpcfg (line 47) | def _slpcfg(url='', hf_hub='', **kwargs):
  function _apcfg (line 60) | def _apcfg(url='', hf_hub='', **kwargs):
  function _mccfg (line 73) | def _mccfg(url='', hf_hub='', **kwargs):
  function _clean_tag (line 519) | def _clean_tag(tag: str):
  function list_pretrained (line 524) | def list_pretrained(as_str: bool = False):
  function list_pretrained_models_by_tag (line 531) | def list_pretrained_models_by_tag(tag: str):
  function list_pretrained_tags_by_model (line 541) | def list_pretrained_tags_by_model(model: str):
  function is_pretrained_cfg (line 549) | def is_pretrained_cfg(model: str, tag: str):
  function get_pretrained_cfg (line 555) | def get_pretrained_cfg(model: str, tag: str):
  function get_pretrained_url (line 562) | def get_pretrained_url(model: str, tag: str):
  function download_pretrained_from_url (line 567) | def download_pretrained_from_url(
  function has_hf_hub (line 613) | def has_hf_hub(necessary=False):
  function download_pretrained_from_hf (line 621) | def download_pretrained_from_hf(
  function download_pretrained (line 632) | def download_pretrained(
  function load_big_vision_weights (line 664) | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
  function convert_mobile_clip_state_dict (line 793) | def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fa...
  function convert_state_dict (line 834) | def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):

FILE: inf_clip/train/data.py
  class CsvDataset (line 24) | class CsvDataset(Dataset):
    method __init__ (line 25) | def __init__(self, input_filename, transforms, img_key, caption_key, s...
    method __len__ (line 36) | def __len__(self):
    method __getitem__ (line 39) | def __getitem__(self, idx):
  class SharedEpoch (line 45) | class SharedEpoch:
    method __init__ (line 46) | def __init__(self, epoch: int = 0):
    method set_value (line 49) | def set_value(self, epoch):
    method get_value (line 52) | def get_value(self):
  class DataInfo (line 57) | class DataInfo:
    method set_epoch (line 62) | def set_epoch(self, epoch):
  function expand_urls (line 69) | def expand_urls(urls, weights=None):
  function get_dataset_size (line 91) | def get_dataset_size(shards):
  function get_imagenet (line 113) | def get_imagenet(args, preprocess_fns, split):
  function count_samples (line 160) | def count_samples(dataloader):
  function filter_no_caption_or_no_image (line 170) | def filter_no_caption_or_no_image(sample):
  function log_and_continue (line 176) | def log_and_continue(exn):
  function group_by_keys_nothrow (line 182) | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes...
  function tarfile_to_samples_nothrow (line 214) | def tarfile_to_samples_nothrow(src, handler=log_and_continue):
  function pytorch_worker_seed (line 222) | def pytorch_worker_seed(increment=0):
  function json_fetch (line 236) | def json_fetch(data, key='caption'):
  class detshuffle2 (line 262) | class detshuffle2(wds.PipelineStage):
    method __init__ (line 263) | def __init__(
    method run (line 275) | def run(self, src):
  class ResampledShards2 (line 294) | class ResampledShards2(IterableDataset):
    method __init__ (line 297) | def __init__(
    method __iter__ (line 324) | def __iter__(self):
  function get_wds_dataset (line 348) | def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False...
  function get_csv_dataset (line 468) | def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=No...
  class SyntheticDataset (line 498) | class SyntheticDataset(Dataset):
    method __init__ (line 500) | def __init__(
    method __len__ (line 516) | def __len__(self):
    method __getitem__ (line 519) | def __getitem__(self, idx):
  function get_synthetic_dataset (line 525) | def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokeni...
  function get_dataset_fn (line 548) | def get_dataset_fn(data_path, dataset_type):
  function get_data (line 568) | def get_data(args, preprocess_fns, epoch=0, tokenizer=None):

FILE: inf_clip/train/engine.py
  function accuracy (line 28) | def accuracy(output, target, topk=(1,)):
  function get_clip_metrics (line 34) | def get_clip_metrics(image_features, text_features, logit_scale):
  function maybe_compute_generative_loss (line 54) | def maybe_compute_generative_loss(model_out):
  function get_memory (line 61) | def get_memory():
  function seconds_to_hms (line 70) | def seconds_to_hms(seconds):
  function cal_grad_norm (line 77) | def cal_grad_norm(model):
  function assign_learning_rate (line 87) | def assign_learning_rate(optimizer, new_lr):
  function _warmup_lr (line 92) | def _warmup_lr(base_lr, warmup_length, step):
  function const_lr (line 96) | def const_lr(optimizer, base_lr, warmup_length, steps):
  function const_lr_cooldown (line 107) | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown...
  function cosine_lr (line 126) | def cosine_lr(optimizer, base_lr, warmup_length, steps):
  function postprocess_clip_output (line 139) | def postprocess_clip_output(model_out):
  function unwrap_model (line 147) | def unwrap_model(model):
  function backward (line 154) | def backward(total_loss, scaler):
  class AverageMeter (line 161) | class AverageMeter(object):
    method __init__ (line 164) | def __init__(self):
    method reset (line 167) | def reset(self):
    method update (line 173) | def update(self, val, n=1):
  class GradientAccum (line 180) | class GradientAccum:
    method __init__ (line 182) | def __init__(self, model, loss, scaler, autocast, input_dtype, device):
    method clear (line 203) | def clear(self):
    method clear_state (line 208) | def clear_state(self):
    method accum_inference (line 216) | def accum_inference(self, images, texts):
    method accum_forward_backward (line 246) | def accum_forward_backward(self):
  class GradientCache (line 292) | class GradientCache:
    method __init__ (line 294) | def __init__(self, model, loss, scaler, autocast, input_dtype, device):
    method clear (line 315) | def clear(self):
    method clear_state (line 320) | def clear_state(self):
    method forward_backward (line 327) | def forward_backward(self, images, texts):
    method accum_inference (line 345) | def accum_inference(self, images, texts):
    method accum_forward_backward (line 376) | def accum_forward_backward(self):
  function train_one_epoch (line 432) | def train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer...
  function evaluate (line 573) | def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
  function zero_shot_run (line 682) | def zero_shot_run(model, classifier, dataloader, args):
  function zero_shot_eval (line 709) | def zero_shot_eval(model, data, epoch, args, tokenizer=None):
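
Among the engine.py helpers, the learning-rate schedulers have their full signatures in the index (e.g. cosine_lr(optimizer, base_lr, warmup_length, steps)). The sketch below assumes the usual open_clip convention that the returned object is a callable applied once per optimizer step (linear warmup, then cosine decay); verify against engine.py, and note the import path is inferred from the file layout.

# Hedged sketch: assumes cosine_lr returns scheduler(step) that sets the
# param-group learning rate (warmup for `warmup_length` steps, then cosine decay).
import torch

from inf_clip.train.engine import cosine_lr   # path inferred from inf_clip/train/engine.py

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

total_steps, warmup_steps = 1_000, 100
scheduler = cosine_lr(optimizer, base_lr=1e-3, warmup_length=warmup_steps, steps=total_steps)

for step in range(total_steps):
    scheduler(step)   # assumed to assign optimizer.param_groups[*]["lr"]
    # forward / backward / optimizer.step() would follow in a real training loop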

FILE: inf_clip/train/main.py
  function random_seed (line 42) | def random_seed(seed=42, rank=0):
  function natural_key (line 50) | def natural_key(string_):
  function copy_codebase (line 55) | def copy_codebase(args):
  function prepare_logging (line 72) | def prepare_logging(args):
  function get_latest_checkpoint (line 127) | def get_latest_checkpoint(path: str, remote : bool):
  function prepare_resuming (line 143) | def prepare_resuming(args):
  function prepare_remote_sync (line 178) | def prepare_remote_sync(args):
  function prepare_model (line 205) | def prepare_model(args, device):
  function prepare_optimizer_scaler (line 308) | def prepare_optimizer_scaler(args, model):
  function prepare_scheduler (line 360) | def prepare_scheduler(args, optimizer, num_batches):
  function main (line 383) | def main(args):

FILE: inf_clip/train/optims.py
  class ScalingViTAdafactor (line 8) | class ScalingViTAdafactor(Optimizer):
    method __init__ (line 18) | def __init__(
    method _get_lr (line 52) | def _get_lr(param_group, param_state):
    method _get_options (line 63) | def _get_options(param_group, param_shape):
    method _rms (line 69) | def _rms(tensor):
    method _approx_sq_grad (line 73) | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
    method step (line 81) | def step(self, closure=None):
  class Lion (line 179) | class Lion(Optimizer):
    method __init__ (line 184) | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    method step (line 206) | def step(self, closure=None):
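
optims.py ships two optimizers; the Lion constructor signature is shown in full above, so a usage sketch is straightforward. It is a standard torch.optim.Optimizer driving loop; only the import path is an inference from the file layout.

# Hedged sketch: Lion driven like any torch Optimizer, per the listed signature
# Lion(params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0).
import torch

from inf_clip.train.optims import Lion   # path inferred from inf_clip/train/optims.py

model = torch.nn.Linear(32, 32)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)

x = torch.randn(8, 32)
loss = model(x).pow(2).mean()   # toy objective just to produce gradients
loss.backward()
optimizer.step()                # sign-of-momentum update, per the Lion algorithm
optimizer.zero_grad()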

FILE: inf_clip/train/params.py
  function get_default_params (line 9) | def get_default_params(model_name):
  class ParseKwargs (line 18) | class ParseKwargs(argparse.Action):
    method __call__ (line 19) | def __call__(self, parser, namespace, values, option_string=None):
  function parse_args (line 30) | def parse_args(args):
  function create_deepspeed_config (line 507) | def create_deepspeed_config(args):

FILE: inf_clip/train/utils.py
  function setup_logging (line 19) | def setup_logging(log_file, level, include_host=False):
  function remote_sync_s3 (line 43) | def remote_sync_s3(local_dir, remote_dir):
  function remote_sync_fsspec (line 54) | def remote_sync_fsspec(local_dir, remote_dir):
  function remote_sync (line 79) | def remote_sync(local_dir, remote_dir, protocol):
  function keep_running_remote_sync (line 90) | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
  function start_sync_process (line 96) | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
  function pt_save (line 102) | def pt_save(pt_obj, file_path):
  function pt_load (line 108) | def pt_load(file_path, map_location=None):
  function check_exists (line 117) | def check_exists(file_path):
  function get_autocast (line 126) | def get_autocast(precision):
  function is_global_master (line 140) | def is_global_master(args):
  function is_local_master (line 144) | def is_local_master(args):
  function is_master (line 148) | def is_master(args, local=False):
  function is_using_horovod (line 152) | def is_using_horovod():
  function is_using_distributed (line 163) | def is_using_distributed():
  function world_info_from_env (line 171) | def world_info_from_env():
  function init_distributed_device (line 191) | def init_distributed_device(args):
  function broadcast_object (line 254) | def broadcast_object(args, obj, src=0):
  function all_gather_object (line 267) | def all_gather_object(args, obj, dst=0):

FILE: inf_clip/utils.py
  function freeze_batch_norm_2d (line 9) | def freeze_batch_norm_2d(module, module_match={}, name=''):
  function _ntuple (line 49) | def _ntuple(n):
  function replace_linear (line 65) | def replace_linear(model, linear_replacement, include_modules=['c_fc', '...
  function convert_int8_model_to_inference_mode (line 84) | def convert_int8_model_to_inference_mode(model):

FILE: inf_clip/zero_shot_classifier.py
  function batched (line 9) | def batched(iterable, n):
  function build_zero_shot_classifier (line 21) | def build_zero_shot_classifier(
  function build_zero_shot_classifier_legacy (line 71) | def build_zero_shot_classifier_legacy(

FILE: tests/example.py
  function create_cl_tensors (line 9) | def create_cl_tensors(rank, world_size):
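
The top-level inf_cl package (see the inf_cl/__init__.py preview below) exports cal_flash_loss, cal_ring_loss, and cal_inf_loss, and tests/example.py drives them per rank. The following is a hedged sketch of that call pattern under stated assumptions: cal_inf_loss takes two L2-normalized per-rank feature tensors (image, text) of shape [local_batch, dim] and returns a scalar loss, and it requires an initialized torch.distributed process group on CUDA devices. Check inf_cl/ring.py for the exact signature, optional arguments (e.g. a temperature/scale), and dtype requirements.

# Hedged sketch; launch with: torchrun --nproc_per_node=<num_gpus> this_script.py
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F

from inf_cl import cal_inf_loss   # exported alongside cal_flash_loss, cal_ring_loss


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    local_batch, dim = 1024, 512
    # Stand-in features; in real training these come from the image/text towers.
    image_feats = F.normalize(torch.randn(local_batch, dim, device="cuda"), dim=-1)
    text_feats = F.normalize(torch.randn(local_batch, dim, device="cuda"), dim=-1)

    loss = cal_inf_loss(image_feats, text_feats)   # assumed minimal call; see ring.py
    if dist.get_rank() == 0:
        print("inf-cl loss:", loss.item())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
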
Condensed preview — 152 files, each showing path, character count, and a content snippet (full structured content: 516K chars).
[
  {
    "path": ".gitattributes",
    "chars": 61,
    "preview": "*.py linguist-language=python\n*.ipynb linguist-documentation\n"
  },
  {
    "path": ".gitignore",
    "chars": 2033,
    "preview": "**/logs/\n**/wandb/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Di"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 10782,
    "preview": "<p align=\"center\">\n    <img src=\"https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad\" width="
  },
  {
    "path": "inf_cl/__init__.py",
    "chars": 80,
    "preview": "from .flash import cal_flash_loss\nfrom .ring  import cal_ring_loss, cal_inf_loss"
  },
  {
    "path": "inf_cl/flash.py",
    "chars": 12883,
    "preview": "import math\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nimport triton\nimport triton.language as tl"
  },
  {
    "path": "inf_cl/ring.py",
    "chars": 12506,
    "preview": "import os\nimport math\nimport random\n\nimport torch\nimport torch.distributed as dist\nimport torch.distributed.nn as dist_n"
  },
  {
    "path": "inf_clip/__init__.py",
    "chars": 1269,
    "preview": "from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\n\nfrom .factory import create_model, create_model_and_tran"
  },
  {
    "path": "inf_clip/constants.py",
    "chars": 256,
    "preview": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\nIMAG"
  },
  {
    "path": "inf_clip/factory.py",
    "chars": 19619,
    "preview": "import json\nimport logging\nimport os\nimport re\nfrom copy import deepcopy\nfrom dataclasses import asdict\nfrom pathlib imp"
  },
  {
    "path": "inf_clip/model_configs/EVA01-g-14-plus.json",
    "chars": 401,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva_giant_patch14_22"
  },
  {
    "path": "inf_clip/model_configs/EVA01-g-14.json",
    "chars": 400,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva_giant_patch14_22"
  },
  {
    "path": "inf_clip/model_configs/EVA02-B-16.json",
    "chars": 404,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_base_patch16_cl"
  },
  {
    "path": "inf_clip/model_configs/EVA02-E-14-plus.json",
    "chars": 411,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_enormous_patch"
  },
  {
    "path": "inf_clip/model_configs/EVA02-E-14.json",
    "chars": 411,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_enormous_patch"
  },
  {
    "path": "inf_clip/model_configs/EVA02-L-14-336.json",
    "chars": 406,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"timm_model_name\": \"eva02_large_patch14_c"
  },
  {
    "path": "inf_clip/model_configs/EVA02-L-14.json",
    "chars": 406,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_large_patch14_c"
  },
  {
    "path": "inf_clip/model_configs/LiT-B-16.json",
    "chars": 483,
    "preview": "{\n    \"arch\": \"LiT-B-16\",\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\""
  },
  {
    "path": "inf_clip/model_configs/LiT-B-32.json",
    "chars": 483,
    "preview": "{\n    \"arch\": \"LiT-B-32\",\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\""
  },
  {
    "path": "inf_clip/model_configs/LiT-L-16.json",
    "chars": 487,
    "preview": "{\n    \"arch\": \"LiT-L-16\",\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name"
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-B.json",
    "chars": 483,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_base_mci_224\",\n        \"timm_model_pretraine"
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-S1.json",
    "chars": 476,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"fastvit_mci1\",\n        \"timm_model_pretrained\": "
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-S2.json",
    "chars": 476,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"fastvit_mci2\",\n        \"timm_model_pretrained\": "
  },
  {
    "path": "inf_clip/model_configs/RN101-quickgelu.json",
    "chars": 388,
    "preview": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n     "
  },
  {
    "path": "inf_clip/model_configs/RN101.json",
    "chars": 364,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,"
  },
  {
    "path": "inf_clip/model_configs/RN50-quickgelu.json",
    "chars": 389,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n    "
  },
  {
    "path": "inf_clip/model_configs/RN50.json",
    "chars": 364,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4"
  },
  {
    "path": "inf_clip/model_configs/RN50x16.json",
    "chars": 365,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"layers\": [\n            6,\n            8,"
  },
  {
    "path": "inf_clip/model_configs/RN50x4.json",
    "chars": 365,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": [\n            4,\n            6,"
  },
  {
    "path": "inf_clip/model_configs/RN50x64.json",
    "chars": 370,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 448,\n        \"layers\": [\n            3,\n            1"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-256.json",
    "chars": 710,
    "preview": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 2"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-384.json",
    "chars": 710,
    "preview": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 3"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-512.json",
    "chars": 710,
    "preview": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 5"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json",
    "chars": 720,
    "preview": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 2"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP.json",
    "chars": 710,
    "preview": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 2"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-plus-240.json",
    "chars": 295,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 240,\n        \"layers\": 12,\n        \"width\": 896,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-plus.json",
    "chars": 295,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 896,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-quickgelu.json",
    "chars": 318,
    "preview": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-256.json",
    "chars": 295,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-plus-256.json",
    "chars": 295,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 896,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-quickgelu.json",
    "chars": 318,
    "preview": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-378-quickgelu.json",
    "chars": 348,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 378,\n        \"layers\": 32,\n  "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-CLIPA-336.json",
    "chars": 604,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-CLIPA.json",
    "chars": 604,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-quickgelu.json",
    "chars": 348,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n  "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14.json",
    "chars": 324,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-H-16.json",
    "chars": 324,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-280.json",
    "chars": 296,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 280,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-336.json",
    "chars": 296,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-CLIPA-336.json",
    "chars": 576,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-CLIPA.json",
    "chars": 576,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-quickgelu.json",
    "chars": 320,
    "preview": "{\n    \"embed_dim\": 768,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14.json",
    "chars": 296,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-320.json",
    "chars": 296,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 320,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-SigLIP-256.json",
    "chars": 713,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-SigLIP-384.json",
    "chars": 713,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": "
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16.json",
    "chars": 296,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/ViT-M-16-alt.json",
    "chars": 325,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-M-16.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-M-32-alt.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-M-32.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-S-16-alt.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-S-16.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-S-32-alt.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-S-32.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json",
    "chars": 743,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": "
  },
  {
    "path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP.json",
    "chars": 743,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": "
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json",
    "chars": 634,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 48,\n        \"width\": 1664,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14-CLIPA.json",
    "chars": 634,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14.json",
    "chars": 354,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-e-14.json",
    "chars": 354,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 56,\n        \"width\": 1792,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViT-g-14.json",
    "chars": 353,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 40,\n        \"width\": 1408,\n   "
  },
  {
    "path": "inf_clip/model_configs/ViTamin-B-LTT.json",
    "chars": 426,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_base_224\",\n      \"timm_model_pretrained\": "
  },
  {
    "path": "inf_clip/model_configs/ViTamin-B.json",
    "chars": 425,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_base_224\",\n      \"timm_model_pretrained\": "
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L-256.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_256\",\n      \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L-336.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_336\",\n      \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_224\",\n      \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2-256.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_256\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2-336.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_336\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_224\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-S-LTT.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_small_224\",\n      \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-S.json",
    "chars": 426,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_small_224\",\n      \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-256.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_256\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-336.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_336\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-384.json",
    "chars": 430,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_384\",\n      \"timm_model_pretrained"
  },
  {
    "path": "inf_clip/model_configs/coca_ViT-B-32.json",
    "chars": 659,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/coca_ViT-L-14.json",
    "chars": 664,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "inf_clip/model_configs/coca_base.json",
    "chars": 669,
    "preview": "{\n    \"embed_dim\": 512,\n    \"multimodal_cfg\": {\n        \"width\": 768,\n        \"context_length\": 76,\n        \"vocab_size\""
  },
  {
    "path": "inf_clip/model_configs/coca_roberta-ViT-B-32.json",
    "chars": 525,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/convnext_base.json",
    "chars": 421,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/convnext_base_w.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/convnext_base_w_320.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "inf_clip/model_configs/convnext_large.json",
    "chars": 423,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "inf_clip/model_configs/convnext_large_d.json",
    "chars": 420,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "inf_clip/model_configs/convnext_large_d_320.json",
    "chars": 420,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "inf_clip/model_configs/convnext_small.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_small\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "inf_clip/model_configs/convnext_tiny.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_tiny\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "inf_clip/model_configs/convnext_xlarge.json",
    "chars": 426,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xlarge\",\n        \"timm_model_pretraine"
  },
  {
    "path": "inf_clip/model_configs/convnext_xxlarge.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrain"
  },
  {
    "path": "inf_clip/model_configs/convnext_xxlarge_320.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrain"
  },
  {
    "path": "inf_clip/model_configs/mt5-base-ViT-B-32.json",
    "chars": 305,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/mt5-xl-ViT-H-14.json",
    "chars": 329,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-base-siglip.json",
    "chars": 509,
    "preview": "{\n    \"embed_dim\": 768,\n    \"custom_text\": true,\n    \"init_logit_bias\": -10,\n    \"vision_cfg\": {\n        \"image_size\": 3"
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-base.json",
    "chars": 371,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-large-siglip.json",
    "chars": 512,
    "preview": "{\n    \"embed_dim\": 1152,\n    \"custom_text\": true,\n    \"init_logit_bias\": -10,\n    \"vision_cfg\": {\n        \"image_size\": "
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-large.json",
    "chars": 399,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/model_configs/roberta-ViT-B-32.json",
    "chars": 323,
    "preview": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n   "
  },
  {
    "path": "inf_clip/model_configs/swin_base_patch4_window7_224.json",
    "chars": 380,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"swin_base_patch4_window7_224\",\n        \"timm_mod"
  },
  {
    "path": "inf_clip/model_configs/vit_medium_patch16_gap_256.json",
    "chars": 377,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_medium_patch16_gap_256\",\n        \"timm_model"
  },
  {
    "path": "inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json",
    "chars": 384,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_relpos_medium_patch16_cls_224\",\n        \"tim"
  },
  {
    "path": "inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json",
    "chars": 307,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json",
    "chars": 337,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "inf_clip/models/clip_arch.py",
    "chars": 24432,
    "preview": "\"\"\" CLIP Model\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimpo"
  },
  {
    "path": "inf_clip/models/coca_arch.py",
    "chars": 18994,
    "preview": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nf"
  },
  {
    "path": "inf_clip/models/hf_configs.py",
    "chars": 2422,
    "preview": "# HF architecture dict:\narch_dict = {\n    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n    \"robe"
  },
  {
    "path": "inf_clip/models/hf_model.py",
    "chars": 7597,
    "preview": "\"\"\" huggingface model adapter\n\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for u"
  },
  {
    "path": "inf_clip/models/lit_arch.py",
    "chars": 7051,
    "preview": "from dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\n"
  },
  {
    "path": "inf_clip/models/loss.py",
    "chars": 23890,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\ntry:\n    import torch.distributed.nn\n    from t"
  },
  {
    "path": "inf_clip/models/modified_resnet.py",
    "chars": 7027,
    "preview": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom ..util"
  },
  {
    "path": "inf_clip/models/pos_embed.py",
    "chars": 4044,
    "preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
  },
  {
    "path": "inf_clip/models/timm_model.py",
    "chars": 6090,
    "preview": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower "
  },
  {
    "path": "inf_clip/models/tokenizer.py",
    "chars": 18171,
    "preview": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\ni"
  },
  {
    "path": "inf_clip/models/transform.py",
    "chars": 14341,
    "preview": "import numbers\nimport random\nimport warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, Lis"
  },
  {
    "path": "inf_clip/models/transformer.py",
    "chars": 34106,
    "preview": "from collections import OrderedDict\nimport math\nfrom typing import Callable, List, Optional, Sequence, Tuple, Union\nfrom"
  },
  {
    "path": "inf_clip/openai.py",
    "chars": 3314,
    "preview": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c"
  },
  {
    "path": "inf_clip/pretrained.py",
    "chars": 35037,
    "preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nimp"
  },
  {
    "path": "inf_clip/train/data.py",
    "chars": 21240,
    "preview": "import ast\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport braceexpand\nfrom dataclasses"
  },
  {
    "path": "inf_clip/train/engine.py",
    "chars": 27949,
    "preview": "import json\nimport logging\nimport math\nimport os\nimport time\nfrom contextlib import nullcontext\n\nimport numpy as np\nimpo"
  },
  {
    "path": "inf_clip/train/main.py",
    "chars": 22222,
    "preview": "import glob\nimport logging\nimport os\nimport re\nimport subprocess\nimport sys\nimport random\nimport time\nfrom functools imp"
  },
  {
    "path": "inf_clip/train/optims.py",
    "chars": 9273,
    "preview": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.optim import Optimizer\n\n\nclass ScalingViTAdafactor(Optimizer):"
  },
  {
    "path": "inf_clip/train/params.py",
    "chars": 21537,
    "preview": "import os\nimport argparse\nimport ast\nimport json\n\nfrom .utils import world_info_from_env\n\n\ndef get_default_params(model_"
  },
  {
    "path": "inf_clip/train/utils.py",
    "chars": 8893,
    "preview": "import os\nimport time\nimport logging\nimport subprocess\nimport multiprocessing\n\nimport torch\nimport torch.distributed as "
  },
  {
    "path": "inf_clip/utils.py",
    "chars": 3472,
    "preview": "from itertools import repeat\nimport collections.abc\n\nimport torch\nfrom torch import nn as nn\nfrom torchvision.ops.misc i"
  },
  {
    "path": "inf_clip/zero_shot_classifier.py",
    "chars": 4340,
    "preview": "from functools import partial\nfrom itertools import islice\nfrom typing import Callable, List, Optional, Sequence, Union\n"
  },
  {
    "path": "inf_clip/zero_shot_metadata.py",
    "chars": 19245,
    "preview": "\nOPENAI_IMAGENET_TEMPLATES = (\n    lambda c: f'a bad photo of a {c}.',\n    lambda c: f'a photo of many {c}.',\n    lambda"
  },
  {
    "path": "pyproject.toml",
    "chars": 1394,
    "preview": "[build-system]\nrequires = [\"pdm-backend\"]\nbuild-backend = \"pdm.backend\"\n\n[project]\nname = \"inf-cl\"\nversion = \"1.2\"\nautho"
  },
  {
    "path": "requirements.txt",
    "chars": 446,
    "preview": "--extra-index-url https://download.pytorch.org/whl/cu118\n# basic dependencies\ntorch==2.2.0\ntorchvision==0.17.0\nnumpy==1."
  },
  {
    "path": "scripts/benchmarks_eval.sh",
    "chars": 293,
    "preview": "clip_benchmark eval \\\n    --model LiT-B-16 \\\n    --pretrained work_dirs/epoch_8.pt \\\n    --dataset datasets/imagenet.txt"
  },
  {
    "path": "scripts/cc12m/clip_vit-b-32_bs32k.sh",
    "chars": 1862,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/cc12m/lit_vit-b-16_bs32k.sh",
    "chars": 1931,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/cc12m/lit_vit-b-32_bs32k.sh",
    "chars": 1931,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/cc3m/clip_r50_bs4k.sh",
    "chars": 1874,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/cc3m/clip_vit-b-32_bs16k.sh",
    "chars": 1890,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/cc3m/lit_vit-b-32_bs16k.sh",
    "chars": 1927,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/imagenet_eval.sh",
    "chars": 171,
    "preview": "torchrun --nproc_per_node 1 \\\n    -m inf_cl_train.main \\\n    --imagenet-val datasets/imagenet-1k/val \\\n    --model ViT-B"
  },
  {
    "path": "scripts/laion400m/clip_vit-b-32_bs256k.sh",
    "chars": 1821,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/laion400m/lit_vit-b-16_bs256k.sh",
    "chars": 1941,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/laion400m/lit_vit-b-32_bs256k.sh",
    "chars": 1941,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "scripts/laion400m/lit_vit-l-16_bs256k.sh",
    "chars": 1941,
    "preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
  },
  {
    "path": "tests/example.py",
    "chars": 1496,
    "preview": "import torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport numpy as np\n\nfrom inf_cl import cal"
  }
]

About this extraction

This page contains the full source code of the DAMO-NLP-SG/Inf-CLIP GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 152 files (471.8 KB), approximately 130.5k tokens, and a symbol index with 447 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.
