Showing preview only (512K chars total). Download the full file or copy to clipboard to get everything.
Repository: DAMO-NLP-SG/Inf-CLIP
Branch: main
Commit: d9f2833b3753
Files: 152
Total size: 471.8 KB
Directory structure:
gitextract_u47kktxp/
├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── inf_cl/
│ ├── __init__.py
│ ├── flash.py
│ └── ring.py
├── inf_clip/
│ ├── __init__.py
│ ├── constants.py
│ ├── factory.py
│ ├── model_configs/
│ │ ├── EVA01-g-14-plus.json
│ │ ├── EVA01-g-14.json
│ │ ├── EVA02-B-16.json
│ │ ├── EVA02-E-14-plus.json
│ │ ├── EVA02-E-14.json
│ │ ├── EVA02-L-14-336.json
│ │ ├── EVA02-L-14.json
│ │ ├── LiT-B-16.json
│ │ ├── LiT-B-32.json
│ │ ├── LiT-L-16.json
│ │ ├── MobileCLIP-B.json
│ │ ├── MobileCLIP-S1.json
│ │ ├── MobileCLIP-S2.json
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── RN50x64.json
│ │ ├── ViT-B-16-SigLIP-256.json
│ │ ├── ViT-B-16-SigLIP-384.json
│ │ ├── ViT-B-16-SigLIP-512.json
│ │ ├── ViT-B-16-SigLIP-i18n-256.json
│ │ ├── ViT-B-16-SigLIP.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16-quickgelu.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-256.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14-378-quickgelu.json
│ │ ├── ViT-H-14-CLIPA-336.json
│ │ ├── ViT-H-14-CLIPA.json
│ │ ├── ViT-H-14-quickgelu.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14-CLIPA-336.json
│ │ ├── ViT-L-14-CLIPA.json
│ │ ├── ViT-L-14-quickgelu.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16-SigLIP-256.json
│ │ ├── ViT-L-16-SigLIP-384.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-M-16-alt.json
│ │ ├── ViT-M-16.json
│ │ ├── ViT-M-32-alt.json
│ │ ├── ViT-M-32.json
│ │ ├── ViT-S-16-alt.json
│ │ ├── ViT-S-16.json
│ │ ├── ViT-S-32-alt.json
│ │ ├── ViT-S-32.json
│ │ ├── ViT-SO400M-14-SigLIP-384.json
│ │ ├── ViT-SO400M-14-SigLIP.json
│ │ ├── ViT-bigG-14-CLIPA-336.json
│ │ ├── ViT-bigG-14-CLIPA.json
│ │ ├── ViT-bigG-14.json
│ │ ├── ViT-e-14.json
│ │ ├── ViT-g-14.json
│ │ ├── ViTamin-B-LTT.json
│ │ ├── ViTamin-B.json
│ │ ├── ViTamin-L-256.json
│ │ ├── ViTamin-L-336.json
│ │ ├── ViTamin-L.json
│ │ ├── ViTamin-L2-256.json
│ │ ├── ViTamin-L2-336.json
│ │ ├── ViTamin-L2.json
│ │ ├── ViTamin-S-LTT.json
│ │ ├── ViTamin-S.json
│ │ ├── ViTamin-XL-256.json
│ │ ├── ViTamin-XL-336.json
│ │ ├── ViTamin-XL-384.json
│ │ ├── coca_ViT-B-32.json
│ │ ├── coca_ViT-L-14.json
│ │ ├── coca_base.json
│ │ ├── coca_roberta-ViT-B-32.json
│ │ ├── convnext_base.json
│ │ ├── convnext_base_w.json
│ │ ├── convnext_base_w_320.json
│ │ ├── convnext_large.json
│ │ ├── convnext_large_d.json
│ │ ├── convnext_large_d_320.json
│ │ ├── convnext_small.json
│ │ ├── convnext_tiny.json
│ │ ├── convnext_xlarge.json
│ │ ├── convnext_xxlarge.json
│ │ ├── convnext_xxlarge_320.json
│ │ ├── mt5-base-ViT-B-32.json
│ │ ├── mt5-xl-ViT-H-14.json
│ │ ├── nllb-clip-base-siglip.json
│ │ ├── nllb-clip-base.json
│ │ ├── nllb-clip-large-siglip.json
│ │ ├── nllb-clip-large.json
│ │ ├── roberta-ViT-B-32.json
│ │ ├── swin_base_patch4_window7_224.json
│ │ ├── vit_medium_patch16_gap_256.json
│ │ ├── vit_relpos_medium_patch16_cls_224.json
│ │ ├── xlm-roberta-base-ViT-B-32.json
│ │ └── xlm-roberta-large-ViT-H-14.json
│ ├── models/
│ │ ├── clip_arch.py
│ │ ├── coca_arch.py
│ │ ├── hf_configs.py
│ │ ├── hf_model.py
│ │ ├── lit_arch.py
│ │ ├── loss.py
│ │ ├── modified_resnet.py
│ │ ├── pos_embed.py
│ │ ├── timm_model.py
│ │ ├── tokenizer.py
│ │ ├── transform.py
│ │ └── transformer.py
│ ├── openai.py
│ ├── pretrained.py
│ ├── train/
│ │ ├── data.py
│ │ ├── engine.py
│ │ ├── main.py
│ │ ├── optims.py
│ │ ├── params.py
│ │ └── utils.py
│ ├── utils.py
│ ├── zero_shot_classifier.py
│ └── zero_shot_metadata.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── benchmarks_eval.sh
│ ├── cc12m/
│ │ ├── clip_vit-b-32_bs32k.sh
│ │ ├── lit_vit-b-16_bs32k.sh
│ │ └── lit_vit-b-32_bs32k.sh
│ ├── cc3m/
│ │ ├── clip_r50_bs4k.sh
│ │ ├── clip_vit-b-32_bs16k.sh
│ │ └── lit_vit-b-32_bs16k.sh
│ ├── imagenet_eval.sh
│ └── laion400m/
│ ├── clip_vit-b-32_bs256k.sh
│ ├── lit_vit-b-16_bs256k.sh
│ ├── lit_vit-b-32_bs256k.sh
│ └── lit_vit-l-16_bs256k.sh
└── tests/
└── example.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
*.py linguist-language=python
*.ipynb linguist-documentation
================================================
FILE: .gitignore
================================================
**/logs/
**/wandb/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*
# Allow
!src/evaluation/misc/results_dbs/*
# log dirs
/work_dirs*/
/datasets/
# oss logs
/.ossutil*
/ossutil*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
<p align="center">
<img src="https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad" width="150" style="margin-bottom: 0.2;"/>
<p>
<h3 align="center"><a href="https://arxiv.org/abs/2410.17243">
Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss</a></h3>
<h5 align="center"> If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏 </h5>
<h5 align="center">
[](https://arxiv.org/abs/2410.17243)
[](https://huggingface.co/papers/2410.17243)
[](https://pypi.org/project/inf-cl) <br>
[](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/LICENSE)
[](https://hits.seeyoufarm.com)
[](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aopen+is%3Aissue)
[](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aissue+is%3Aclosed) <br>
[](https://zhuanlan.zhihu.com/p/1681887214)
[](https://x.com/lixin4ever/status/1849669129613226457) <br>
</h5>
<div align="center"><img src="https://github.com/user-attachments/assets/2c19838b-43d8-4145-b28c-903f3d76f8ab" width="800" /></div>
<details open><summary>💡 Some other multimodal foundation model projects from our team may interest you ✨. </summary><p>
<!-- may -->
> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922) <br>
> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing <br>
[](https://github.com/DAMO-NLP-SG/VCD) [](https://github.com/DAMO-NLP-SG/VCD) [](https://arxiv.org/abs/2311.16922) <br>
> [**VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs**](https://github.com/DAMO-NLP-SG/VideoLLaMA2) <br>
> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing <br>
[](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [](https://arxiv.org/abs/2406.07476) <br>
> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787) <br>
> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing <br>
[](https://github.com/DAMO-NLP-SG/CMM) [](https://github.com/DAMO-NLP-SG/CMM) [](https://arxiv.org/abs/2410.12787) <br>
</p></details>
## 📰 News
* **[2024.10.18]** Release training and evaluation codes of Inf-CLIP.
<div align="center"><img src="https://github.com/user-attachments/assets/11c5cc32-aac2-497d-bbc1-33e065a71be0" width="800" /></div>
## 🛠️ Requirements and Installation
Basic Dependencies:
* Python >= 3.8
* Pytorch >= 2.0.0
* CUDA Version >= 11.8
[Remote] Install Inf-CL:
```bash
# remote installing
pip install inf_cl -i https://pypi.org/simple
```
[Local] Install Inf-CL:
```bash
pip install -e .
```
Install required packages:
```bash
git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt
```
## ⭐ Features
`inf_cl` is the triton implementation of Inf-CL loss:
* [x] [Ring-CL (inf_cl/ring.py#L238)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L238)
* [x] [Inf-CL (inf_cl/ring.py#L251)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L251)
`inf_clip` is the CLIP training codebase with Inf-CL loss and other training features:
- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L180)
- [x] [Gradient Cache (inf_clip/train/train.py#L292)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L292)
## 🔑 Usage
A simple example about how to adopt our Inf-CL loss for contrastive learning. Using such command for attempting:
```
torchrun --nproc_per_node 2 tests/example.py
```
```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from inf_cl import cal_inf_loss
def create_cl_tensors(rank, world_size):
# Parameters
dtype = torch.float32
num_heads = 3 # Number of attention heads
seq_length_q = 32768 # Sequence length
seq_length_k = 32768
d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128)
# Randomly initialize inputs
q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)
q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query
k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key
l = l.requires_grad_() # Logit scale
return q, k, l
if __name__ == "__main__":
# Assume that the distributed environment has been initialized
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
# Exampled by Image-Text Contrastive Learning, q is the global image features,
# k is the text features, and l is the logit scale.
q, k, l = create_cl_tensors(rank, world_size)
# labels are diagonal elements by default.
# labels = torch.arange(q.shape[0])
loss = cal_inf_loss(q, k, scale=l.exp())
print(loss)
```
## 🚀 Main Results
### Memory Cost
<p><img src="https://github.com/user-attachments/assets/05dd3fea-0a93-4716-b321-0a94965e1fbe" width="800" /></p>
\* denotes adopting "data offload" strategy.
### Max Supported Batch Size
<p><img src="https://github.com/user-attachments/assets/eb38fb90-3b7e-4696-b078-b7766893f758" width="800" /></p>
### Speed
<p><img src="https://github.com/user-attachments/assets/da72e99b-508b-450a-b12e-401d4991291a" width="800" /></p>
### Batch Size Scaling
<p><img src="https://github.com/user-attachments/assets/5b55fa98-6558-4509-9b66-e290ecf77b41" width="800" /></p>
Training with larger data scale needs larger batch size.
## 🗝️ Training & Evaluation
### Quick Start
To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP and evaluate the trained model on the mainstream clip benchmarks.
1. Training Data Structure:
```bash
Inf-CLIP
├── datasets
│ ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
| | ├── 0000.tar
| | ├── 0001.tar
| | ├── ...
| | └── 0301.tar
│ ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
| | ├── 0000.tar
| | ├── 0001.tar
| | ├── ...
| | └── 1044.tar
│ ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
| | ├── 00000.tar
| | ├── 00001.tar
| | ├── ...
| | └── 41407.tar
```
2. Command:
```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```
3. Evaluation Data Structure:
```bash
Inf-CLIP
├── datasets
│ ├── imagenet-1k/ # download val_images.tar.gz of imagenet from https://huggingface.co/datasets/ILSVRC/imagenet-1k/tree/main/data
| | └── val/ # python datasets/reformat_imagenet.py
| | | ├── n01440764
| | | ├── n01443537
| | | ├── ...
| | | └── n15075141
│ ├── clip-benchmark/ # bash datasets/benchmarks_download.sh
| | ├── wds_mscoco_captions
| | ├── wds_flickr8k
| | ├── wds_flickr30k
| | ├── wds_imagenet1k
| | ├── wds_imagenetv2
| | ├── wds_imagenet_sketch
| | ├── wds_imagenet-a
| | ├── wds_imagenet-r
| | ├── wds_imagenet-o
| | └── wds_objectnet
```
4. Command:
```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```
## 📑 Citation
If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:
```bibtex
@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}
```
## 👍 Acknowledgement
The codebase of Inf-CLIP is adapted from [**OpenCLIP**](https://github.com/mlfoundations/open_clip). We are also grateful for the following projects our Inf-CL arose from:
* [**OpenAI CLIP**](https://openai.com/index/clip/), [**img2dataset**](https://github.com/rom1504/img2dataset), [**CLIP-Benchmark**](https://github.com/LAION-AI/CLIP_benchmark).
* [**FlashAttention**](https://github.com/Dao-AILab/flash-attention), [**RingAttention**](https://github.com/haoliuhl/ringattention), [**RingFlashAttention**](https://github.com/zhuzilin/ring-flash-attention).
## 🔒 License
This project is released under the Apache 2.0 license as found in the LICENSE file.
The service is a research preview intended for **non-commercial use ONLY**, subject to the model Licenses of CLIP, Terms of Use of the data generated by OpenAI, and Laion. Please get in touch with us if you find any potential violations.
================================================
FILE: inf_cl/__init__.py
================================================
from .flash import cal_flash_loss
from .ring import cal_ring_loss, cal_inf_loss
================================================
FILE: inf_cl/flash.py
================================================
import math
import torch
import torch.nn.functional as F
import numpy as np
import triton
import triton.language as tl
# Forward kernel: for one BLOCK_M-row tile of Q, compute the row-wise
# log-sum-exp (LSE) of the full similarity matrix Q @ K^T without ever
# materializing it, via the streaming max/sum recurrence (FlashAttention-style
# online softmax). One program instance per row tile; the kernel scans all of
# K in BLOCK_N-column tiles.
#
# NOTE(review): pointer arithmetic below (`ndims * row`) assumes Q and K are
# laid out contiguously as (seqlen, nheads * BLOCK_HEADDIM) — the host wrapper
# `_flash_prob_forward` reshapes to (seqlen, nheads, d); confirm contiguity at
# the call site.
@triton.jit
def _prob_fwd_kernel(
    Q,            # pointer to query features
    K,            # pointer to key features
    LSE,          # output pointer: one float32 LSE per query row
    nheads,       # number of heads packed along the feature dimension
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,  # per-head feature width (padded to pow2 by host)
    BLOCK_M: tl.constexpr,        # query rows per program
    BLOCK_N: tl.constexpr,        # key rows per inner tile
):
    # start index of sequence length (which BLOCK_M tile of Q this program owns)
    start_m = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM  # stride (in elements) between consecutive rows
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    q_ptrs = Q + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    # initialize pointer to m and l:
    #   m_i  = running row-wise max of the logits seen so far (starts at -inf)
    #   lse_i = running sum of exp(logit - m_i)
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # loop over k, v and update accumulator
    end_n = seqlen_k
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # accumulate the tile of logits head by head (features are packed
        # per-head along the last dimension)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seem to make the result wrong
        m_ij = tl.maximum(tl.max(qk, 1), m_i)  # new running max
        p = tl.exp(qk - m_ij[:, None])
        # Fix out of bound access: zero contributions from padded key columns
        p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0)
        # -- update statistics: rescale the old sum to the new max, add new tile
        lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1)
        m_i = m_ij
    # fold the running max back in: LSE = m + log(sum exp(x - m))
    lse_i = m_i + tl.log(lse_i)
    # mask out the padded values (rows beyond seqlen_q get 0)
    lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0)
    tl.store(LSE + offs_m, lse_i)
# Backward kernel for dQ: one program per BLOCK_M-row tile of Q, scanning all
# of K. The softmax probabilities are recomputed on the fly from the saved LSE
# (p = exp(qk - lse)) instead of being stored, then scaled by the incoming
# gradient dLSE and multiplied back against K to accumulate dQ.
@triton.jit
def _dq_prob_bwd_kernel(
    Q,
    K,
    dQ,     # output pointer, pre-zeroed float32 by the host wrapper
    LSE,    # saved forward log-sum-exp per query row
    dLSE,   # upstream gradient w.r.t. LSE per query row
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Round-to-nearest f32 -> tf32 conversion; used so the manual tl.dot here
    # matches torch's tf32 matmul precision (see issue link below).
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length
    start_m = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    q_ptrs = Q + ndims * offs_m[:, None]
    dq_ptrs = dQ + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    # setting lse (one value per query row of this tile)
    lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0)
    dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # first pass over heads: rebuild the logits tile qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        # p = softmax numerator / denominator, recomputed from saved LSE,
        # zeroed on padded key columns, then scaled by the upstream gradient
        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        # second pass over heads: dQ_h += qk_grad @ K_h
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute q grad ----
            # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch
            # A solution for this problem
            # Refer to issue: https://github.com/triton-lang/triton/issues/4574
            # if allow_tf32:
            k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1)
            q_grad = tl.dot(qk_grad, k)
            # Another solution for this problem
            # Refer to https://github.com/triton-lang/triton/issues/376
            # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
            # -- store dq ---- (read-modify-write accumulate across key tiles)
            dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)
# Backward kernel for dK: mirror image of _dq_prob_bwd_kernel — one program
# per BLOCK_N-row tile of K, scanning all of Q. Probabilities are again
# recomputed from the saved LSE, scaled by dLSE, and the transposed product
# accumulates into dK.
@triton.jit
def _dk_prob_bwd_kernel(
    Q,
    K,
    dK,     # output pointer, pre-zeroed float32 by the host wrapper
    LSE,    # saved forward log-sum-exp per query row
    dLSE,   # upstream gradient w.r.t. LSE per query row
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Round-to-nearest f32 -> tf32 conversion (same precision workaround as in
    # _dq_prob_bwd_kernel).
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length (which BLOCK_N tile of K this program owns)
    start_n = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Initialize pointers to Q, K, V
    q_ptrs = Q + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    dk_ptrs = dK + ndims * offs_n[:, None]
    # loop over q and update accumulator
    end_m = seqlen_q
    for start_m in range(0, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # setting lse / dlse for the current query tile
        # NOTE(review): masks here use `offs_m < seqlen_q` rather than
        # `offs_m + start_m < seqlen_q` as used for the q loads below — confirm
        # this is intentional (qk_grad is re-masked on rows afterwards).
        lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)
        dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)
        # first pass over heads: rebuild the logits tile qk
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        # recompute probabilities from saved LSE, zero padded query rows,
        # scale by the upstream gradient
        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        # second pass over heads: dK_h += qk_grad^T @ Q_h
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute k grad ----
            q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1)
            k_grad = tl.dot(tl.trans(qk_grad), q)
            # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32))
            # -- store dk ---- (read-modify-write accumulate across query tiles)
            dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k)
def _flash_prob_forward(q, k):
    """Compute log-sum-exp over the similarity matrix q @ k^T with a Triton kernel.

    Args:
        q: query tensor of shape [seqlen_q, nheads, d] on a CUDA device.
        k: key tensor of shape [seqlen_k, nheads, d], same dtype as q.

    Returns:
        lse: float32 tensor of shape [seqlen_q] with log(sum_j exp(q_i . k_j)).
    """
    # ---- shape / device constraints ----
    seqlen_q, nheads, d = q.shape
    seqlen_k = k.shape[0]
    assert k.shape == (seqlen_k, nheads, d)
    assert q.dtype == k.dtype, "All tensors must have the same type"
    assert q.is_cuda and k.is_cuda
    # round the output buffer up to a multiple of 128 rows for the kernel
    seqlen_q_rounded = 128 * ((seqlen_q + 127) // 128)
    lse = torch.empty(seqlen_q_rounded, device=q.device, dtype=torch.float32)
    # ---- launch configuration ----
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = BLOCK_N = 64
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _prob_fwd_kernel[grid](
        q, k, lse,
        nheads, seqlen_q, seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=8,
        num_stages=1,
    )
    # drop the rounding padding before returning
    return lse[:seqlen_q]
def _flash_prob_backward(q, k, lse, dlse):
    """Backward pass of the flash log-sum-exp.

    Args:
        q: query tensor of shape [seqlen_q, nheads, d] on a CUDA device.
        k: key tensor of shape [seqlen_k, nheads, d], same dtype as q.
        lse: [seqlen_q] log-sum-exp produced by the forward pass.
        dlse: [seqlen_q] upstream gradient w.r.t. lse.

    Returns:
        (dq, dk): float32 gradients with the same shapes as q and k.
    """
    # shape constraints
    seqlen_q, nheads, d = q.shape
    seqlen_k, _, _ = k.shape
    assert k.shape == (seqlen_k, nheads, d)
    assert q.dtype == k.dtype, "All tensors must have the same type"
    assert q.is_cuda and k.is_cuda

    # gradients are accumulated by the kernels, so they must start at zero
    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.zeros_like(k, dtype=torch.float32)

    q = q.contiguous()
    k = k.contiguous()
    dlse = dlse.contiguous()

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = 64
    BLOCK_N = 64
    num_warps = 8
    num_stages = 1

    # dq kernel: one program per BLOCK_M rows of q
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _dq_prob_bwd_kernel[grid](
        q,
        k,
        dq,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # dk kernel: one program per BLOCK_N rows of k.
    # BUGFIX: the original wrote `BLOCK_N = BLOCK_M; BLOCK_M = BLOCK_N`, which
    # does not swap anything (the second assignment is a no-op). Use a real
    # tuple swap; with both block sizes equal to 64 the launch configuration
    # is unchanged, but the intent now survives if the sizes ever diverge.
    BLOCK_M, BLOCK_N = BLOCK_N, BLOCK_M
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), 1)
    _dk_prob_bwd_kernel[grid](
        q,
        k,
        dk,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    dq = dq[:seqlen_q]
    dk = dk[:seqlen_k]
    return dq, dk
class FlashProb(torch.autograd.Function):
    """Autograd wrapper tying the Triton lse forward/backward kernels together."""

    @staticmethod
    def forward(ctx, q, k):
        # forward produces the per-row log-sum-exp; keep inputs for backward
        lse = _flash_prob_forward(q, k)
        ctx.save_for_backward(q, k, lse)
        return lse

    @staticmethod
    def backward(ctx, dlse):
        saved_q, saved_k, saved_lse = ctx.saved_tensors
        grad_q, grad_k = _flash_prob_backward(saved_q, saved_k, saved_lse, dlse)
        return grad_q, grad_k
def _cal_flash_loss(q, k, labels, head_dim=256):
    """Per-sample contrastive cross-entropy via the flash log-sum-exp kernel."""
    # NOTE: logits forward or backward should keep fp32 for better precision.
    # Split the feature dimension into heads of size head_dim.
    q = q.view(q.shape[0], -1, head_dim).float()
    k = k.view(k.shape[0], -1, head_dim).float()
    # denominator term: log-sum-exp over all query/key pairs
    lse = FlashProb.apply(q, k)
    # numerator term: similarity of each query with its positive key
    positive = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    return lse - positive
def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):
    """Flash contrastive loss with optional explicit labels and logit scale."""
    if labels is None:
        # default CLIP pairing: i-th query matches i-th key
        labels = torch.arange(q.shape[0], device=q.device)
    effective_scale = 1.0 if scale is None else scale
    return _cal_flash_loss(effective_scale * q, k, labels, head_dim)
if __name__ == '__main__':
    import time

    # Smoke test / micro-benchmark: compare the Triton flash loss against a
    # straightforward torch implementation on random, L2-normalized features.
    # Requires a CUDA device.

    # Parameters
    num_heads = 3  # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256  # Dimension of each head (must be 16, 32, 64, or 128)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q, num_heads * d_model), dtype=torch.float32, device="cuda")  # Query
    k = torch.rand((seq_length_k, num_heads * d_model), dtype=torch.float32, device="cuda")  # Key
    # learnable logit scale, initialised to log(1 / 0.02)
    l = torch.ones([], device="cuda") * np.log(1 / 0.02); l.requires_grad = True
    q = F.normalize(q, p=2, dim=-1); q.requires_grad = True
    k = F.normalize(k, p=2, dim=-1); k.requires_grad = True

    # independent leaf copies for the Triton path so gradients can be compared
    q1 = q.clone().detach().requires_grad_(True)
    k1 = k.clone().detach().requires_grad_(True)
    l1 = l.clone().detach().requires_grad_(True)

    labels = torch.arange(seq_length_q).cuda()

    for i in range(1000):
        # A. torch gradient (dense logits + cross-entropy reference)
        start = time.time()
        qk = torch.einsum("md,nd->mn", l.exp() * q, k)
        loss = F.cross_entropy(qk, labels, reduction="mean")
        loss.backward()
        end = time.time()

        # B. triton gradient (flash loss under test)
        start1 = time.time()
        loss1 = cal_flash_loss(q1, k1, labels, l1.exp())
        loss1 = loss1.mean()
        loss1.backward()
        end1 = time.time()

        print("========= Difference =========")
        # wall-clock times, then the logit-scale gradients from both paths
        print(end - start, end1 - start1, l.grad, l1.grad)
        # max absolute gradient mismatch between the two implementations
        print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        # clear gradients for the next iteration
        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None
================================================
FILE: inf_cl/ring.py
================================================
import os
import math
import random
import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
import numpy as np
import triton
import triton.language as tl
from .flash import _flash_prob_forward, _flash_prob_backward, _cal_flash_loss
class RingComm:
    """Asynchronous point-to-point ring communicator.

    Each rank sends to (rank + 1) % world_size and receives from
    (rank - 1) % world_size inside the given process group. Usage pattern:
    queue transfers with send_recv(), launch them with commit(), then block
    on wait() before touching the received buffers.
    """

    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []  # queued P2POps, flushed as one batch by commit()
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None  # in-flight requests between commit() and wait()

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size
        # print(f'rank: {self.rank}, send_rank: {self.send_rank}, recv_rank: {self.recv_rank}')

        if process_group is not None:
            # P2POp needs global ranks, not group-local ranks
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(self, to_send, recv_tensor = None):
        # Queue an async send of `to_send` to the next rank and a matching
        # receive from the previous rank; returns the receive buffer
        # (freshly allocated unless `recv_tensor` is supplied).
        # The buffer is NOT valid until commit() + wait() complete.
        if recv_tensor is None:
            res = torch.empty_like(to_send)
        else:
            res = recv_tensor

        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        # Launch all queued send/recv ops as a single batch.
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        # Block until the committed batch completes, then reset for reuse.
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []
class GradientGather(torch.autograd.Function):
    """Identity in the forward pass; all-reduces (sums) the gradient in backward.

    Applied to rank-shared scalars (e.g. the logit scale) so every rank ends
    up with the same accumulated gradient after backward.
    """

    @staticmethod
    def forward(ctx, x):
        # Identity. Nothing from the forward pass is needed in backward, so
        # (unlike the original) we do not call ctx.save_for_backward(x) —
        # saving it only kept an unused tensor alive until backward.
        return x

    @staticmethod
    def backward(ctx, dx):
        # Sum the incoming gradient across all ranks (in-place all-reduce).
        dist.all_reduce(dx)
        return dx
class RingProb(torch.autograd.Function):
    """Ring-parallel log-sum-exp of q @ k^T using plain torch math.

    Each rank holds a shard of q and k. In forward, the k shards rotate
    around the ring; per-block lse values are merged with a numerically
    stable log-add-exp. In backward, each block's softmax gradient is
    recomputed and the partial dk travels around the ring alongside k.
    """

    @staticmethod
    def forward(ctx, q, k, group):
        # q, k: [seq, heads, dim] local shards; group: process group (None = default)
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            # prefetch the next k block while computing on the current one
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()

            # vanilla lse
            qk = torch.einsum("mhd,nhd->mn", q, k)
            block_lse = torch.log(torch.exp(qk).sum(dim=-1))

            if step == 0:
                lse = block_lse
            else:
                # stable merge: lse_new = log(exp(lse) + exp(block_lse))
                lse = lse - F.logsigmoid(lse - block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group

        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)    # rotates the k shards
        d_k_comm = RingComm(ctx.group)  # rotates the accumulated dk

        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # vanilla gradient calculation:
            # softmax probability of each (q_i, k_j) pair under the global lse
            qk = torch.einsum("mhd,nhd->mn", q, k)
            qk_grad = torch.exp(qk - lse[:, None]).float()
            qk_grad = qk_grad * dlse[:, None]

            block_dq_buffer = torch.einsum("mn,nhd->mhd", qk_grad, k.float())
            block_dk_buffer = torch.einsum("nm,mhd->nhd", qk_grad.T, q.float())

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                # add this block's dk onto the partial dk arriving from the ring
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            # forward the partial dk to the next rank
            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()
        # after world_size hops, next_dk is this rank's fully accumulated dk
        return dq, next_dk, None
class InfProb(torch.autograd.Function):
    """Ring-parallel log-sum-exp using the Triton flash kernels per block.

    Same communication pattern as RingProb, but the per-block lse and its
    gradients come from the memory-efficient _flash_prob_forward /
    _flash_prob_backward kernels instead of materialising the full qk matrix.
    """

    @staticmethod
    def forward(ctx, q, k, group):
        # q, k: [seq, heads, dim] local shards; group: process group (None = default)
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            # prefetch the next k block while computing on the current one
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()

            # flash lse
            block_lse = _flash_prob_forward(q, k)

            if step == 0:
                lse = block_lse
            else:
                # stable merge: lse_new = log(exp(lse) + exp(block_lse))
                lse = lse - F.logsigmoid(lse - block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group

        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)    # rotates the k shards
        d_k_comm = RingComm(ctx.group)  # rotates the accumulated dk

        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # flash gradient calculation for the current k block
            block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse)

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                # add this block's dk onto the partial dk arriving from the ring
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            # forward the partial dk to the next rank
            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()
        # after world_size hops, next_dk is this rank's fully accumulated dk
        return dq, next_dk, None
def set_seed(rank, seed=42):
    """Seed the python and torch RNGs, offset by rank so ranks differ."""
    rank_seed = seed + rank
    random.seed(rank_seed)
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed(rank_seed)
    torch.cuda.manual_seed_all(rank_seed)
def _cal_ring_loss(q, k, labels, head_dim=256):
    """Per-sample contrastive loss using the ring (vanilla-math) lse."""
    # reshape [B, D] -> [B, H, head_dim] and keep fp32 for precision
    q = q.view(q.shape[0], -1, head_dim).float()
    k = k.view(k.shape[0], -1, head_dim).float()
    # denominator: log-sum-exp across all ranks' keys
    denominator = RingProb.apply(q, k, None)
    # numerator: similarity of each query with its positive key
    positive = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    return denominator - positive
def _cal_inf_loss(q, k, labels, head_dim=256):
    """Per-sample contrastive loss using the ring + flash-kernel lse."""
    # reshape [B, D] -> [B, H, head_dim] and keep fp32 for precision
    q = q.view(q.shape[0], -1, head_dim).float()
    k = k.view(k.shape[0], -1, head_dim).float()
    # denominator: log-sum-exp across all ranks' keys (Triton kernels)
    denominator = InfProb.apply(q, k, None)
    # numerator: similarity of each query with its positive key
    positive = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    return denominator - positive
def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):
    """The triton implementation of the ring-cl.

    Args:
        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None.
        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256.
    """
    if labels is None:
        # Allocate directly on q's device (the original built the tensor on
        # CPU and copied with .to(); this also matches cal_flash_loss in flash.py).
        labels = torch.arange(q.shape[0], device=q.device)
    if scale is None:
        scale = 1.0
    else:
        # ensure the scale's gradient is summed across ranks in backward
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_ring_loss(scale * q, k, labels, head_dim).mean()
    else:
        # single-process fallback: plain flash loss, no ring communication
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()
def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):
    """The triton implementation of the inf-cl.

    Args:
        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None.
        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256.
    """
    if labels is None:
        # Allocate directly on q's device (the original built the tensor on
        # CPU and copied with .to(); this also matches cal_flash_loss in flash.py).
        labels = torch.arange(q.shape[0], device=q.device)
    if scale is None:
        scale = 1.0
    else:
        # ensure the scale's gradient is summed across ranks in backward
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_inf_loss(scale * q, k, labels, head_dim).mean()
    else:
        # single-process fallback: plain flash loss, no ring communication
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()
if __name__ == "__main__":
    import time

    # Distributed smoke test: compare an all-gather + dense cross-entropy
    # reference against the ring/flash implementation. Launch with torchrun;
    # requires NCCL and one CUDA device per rank.

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(f'cuda:{os.environ["LOCAL_RANK"]}')

    # Parameters
    dtype = torch.float32
    num_heads = 3  # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256  # Dimension of each head (must be 16, 32, 64, or 128)

    # Randomly initialize inputs; each rank holds a 1/world_size shard
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda")
    l = torch.ones([], dtype=dtype, device="cuda") * np.log(1 / 0.07); l = l.requires_grad_()  # Logit scale
    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key

    # independent leaf copies for the triton path so gradients can be compared
    q1 = q.clone().detach().requires_grad_()
    k1 = k.clone().detach().requires_grad_()
    l1 = l.clone().detach().requires_grad_()

    for i in range(1000):
        # A. local torch gradient
        start = time.time()
        # A.1. gather q, k
        gathered_q = [torch.zeros_like(q) for _ in range(world_size)]
        gathered_k = [torch.zeros_like(k) for _ in range(world_size)]
        dist.all_gather(gathered_q, q)
        dist.all_gather(gathered_k, k)
        # re-insert the local (autograd-tracked) shard so gradients flow
        gathered_q[rank] = q
        gathered_k[rank] = k
        all_q = torch.cat(gathered_q, dim=0)
        all_k = torch.cat(gathered_k, dim=0)
        # A.2. calculating qk logits
        qk = torch.einsum("md,nd->mn", l.exp() * all_q, all_k)
        kq = qk.T
        _labels = torch.arange(seq_length_q).to(q.device)
        # A.3. calculating loss (symmetric image->text and text->image)
        loss_i2t = F.cross_entropy(qk, _labels, reduction="mean")
        loss_t2i = F.cross_entropy(kq, _labels, reduction="mean")
        # A.4. scaling loss to normal value
        scale_factor = (all_q.shape[0] / q.shape[0])
        loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor
        loss.backward()
        show_loss = loss.detach().clone()
        dist.all_reduce(show_loss)
        show_loss = show_loss / (world_size * scale_factor)
        end = time.time()

        dist.barrier()

        # B. triton implementation
        start1 = time.time()
        # labels = torch.arange(seq_length_q // world_size).to(q.device)
        loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp())
        loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp())
        loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5
        loss1.backward()
        end1 = time.time()

        dist.barrier()

        if rank == 0:
            # timings, reference loss, all-reduced loss, and triton loss
            print(rank, end - start, end1 - start1, loss, show_loss, loss1)
            # gradient agreement between reference and triton paths
            print(l.grad, l1.grad, torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        # clear gradients for the next iteration
        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None
================================================
FILE: inf_clip/__init__.py
================================================
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .models.tokenizer import SimpleTokenizer, tokenize, decode
from .models.transform import image_transform, AugmentationCfg
from .models.coca_arch import CoCa
from .models.clip_arch import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
================================================
FILE: inf_clip/constants.py
================================================
# Channel-wise RGB normalization statistics (mean, std) used by image preprocessing.

# Statistics used by the OpenAI CLIP models.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
# Standard ImageNet statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Inception-style normalization (maps pixel values to roughly [-1, 1]).
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
================================================
FILE: inf_clip/factory.py
================================================
import json
import logging
import os
import re
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf, convert_state_dict
from .models.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
from .models.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .models.clip_arch import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .models.coca_arch import CoCa
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, FlashClipLoss, RingClipLoss, InfClipLoss, DiscoClipLoss
# Prefix marking model names that should be resolved via the Hugging Face Hub.
HF_HUB_PREFIX = 'hf-hub:'
# Paths scanned for model-architecture JSON config files.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def _rescan_model_configs():
    """Rebuild the _MODEL_CONFIGS registry from all registered config paths."""
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    candidates = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            candidates.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                candidates.extend(config_path.glob(f'*{ext}'))

    for cf in candidates:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
        # only accept configs that look like complete model definitions
        required = ('embed_dim', 'vision_cfg', 'text_cfg')
        if all(key in model_cfg for key in required):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    # keep the registry in natural-sort order by model name
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))


_rescan_model_configs()  # initial populate of model config registry
def list_models():
    """Enumerate available model architectures based on config files."""
    return list(_MODEL_CONFIGS)
def add_model_config(path):
    """Register a model config path or file and refresh the registry."""
    path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
def get_model_config(model_name):
    """Return a deep copy of the named model config, or None if unknown."""
    cfg = _MODEL_CONFIGS.get(model_name)
    return None if cfg is None else deepcopy(cfg)
def _get_hf_config(model_id, cache_dir=None):
    """Download and parse `open_clip_config.json` for a Hugging Face Hub model."""
    config_path = download_pretrained_from_hf(
        model_id, filename='open_clip_config.json', cache_dir=cache_dir)
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    """Build the tokenizer matching a model config.

    HF-hub names fall back to an HFTokenizer when no usable config exists;
    otherwise the config's text_cfg decides between HFTokenizer and the
    built-in SimpleTokenizer.
    """
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name)['model_cfg']
        except Exception:
            # no open_clip config on the hub -> use the HF tokenizer directly
            return HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                **kwargs,
            )
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

    text_config = config.get('text_cfg', {})
    # caller kwargs override any tokenizer kwargs baked into the config
    tokenizer_kwargs = dict(text_config.get('tokenizer_kwargs', {}), **kwargs)

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    if 'hf_tokenizer_name' in text_config:
        return HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    return SimpleTokenizer(
        context_length=context_length,
        **tokenizer_kwargs,
    )
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a checkpoint file and return its state dict.

    Handles three layouts: a dict with a nested 'state_dict', a scripted
    module, or a bare state dict. A leading 'module.' prefix (left by
    DataParallel/DDP wrappers) is stripped from all keys.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        # scripted OpenAI checkpoints carry extra metadata buffers
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    first_key = next(iter(state_dict))
    if first_key.startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
):
    """Load weights from `checkpoint_path` into `model`, converting legacy formats.

    Returns the incompatible-keys result of model.load_state_dict (an empty
    dict for the numpy big_vision path, which loads weights itself).
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.pretrained import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    # NOTE(review): this assumes `model` always exposes a `logit_bias` attribute;
    # if a model lacks it this raises AttributeError — verify against the archs.
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    # adapt checkpoint positional embeddings to the model's grid/context length
    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Instantiate a CLIP-family model (CLIP / CustomTextCLIP / LiT / CoCa).

    The config comes from the local registry, or from the Hugging Face Hub
    when `model_name` carries the 'hf-hub:' prefix. `force_*` arguments
    override individual config entries; `pretrained` selects a weight tag,
    a local checkpoint path, or 'openai' for the original OpenAI weights.
    Returns the model moved to `device` in the requested `precision`.
    """
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        # resolve config + checkpoint from the Hugging Face Hub
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        # OpenAI path builds and loads weights in one step
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None and force_patch_dropout != False:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None and force_image_size != False:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # load pretrained weights for HF text model IFF no CLIP weights being loaded
            # NOTE: disable pretrained_hf arguments.
            # model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
            # NOTE(review): the next line is a self-assignment (no-op) left over
            # from disabling the pretrained_hf override; the config value wins as-is.
            model_cfg['text_cfg']['hf_model_pretrained'] = model_cfg['text_cfg']['hf_model_pretrained']
            # and not pretrained
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        model_arch = model_cfg.pop("arch", "CLIP")
        if custom_text:
            # "arch" selects among the custom-text model families
            if "CLIP" in model_arch:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
            elif "LiT" in model_arch:
                model = LiT(**model_cfg, cast_dtype=cast_dtype)
            elif "CoCa" in model_arch or "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from .transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                # `pretrained` is a known tag -> download the registered weights
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                # `pretrained` is a local checkpoint path
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size

    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model
def create_loss(args):
    """Select the contrastive loss implementation requested by the training args.

    Priority order (first match wins): distillation, CoCa, SigLIP, flash,
    ring, inf, disco, then the default ClipLoss.
    """
    dist_kwargs = dict(rank=args.rank, world_size=args.world_size)
    clip_kwargs = dict(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        use_horovod=args.horovod,
        **dist_kwargs,
    )

    if args.distill_model:
        return DistillClipLoss(**clip_kwargs)
    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **clip_kwargs,
        )
    if args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        return SigLipLoss(**dist_kwargs)

    # the ring-style losses share the same constructor arguments
    ring_kwargs = dict(use_horovod=args.horovod, **dist_kwargs)
    if args.flashloss:
        return FlashClipLoss(**ring_kwargs)
    if args.ringloss:
        return RingClipLoss(**ring_kwargs)
    if args.infloss:
        return InfClipLoss(**ring_kwargs)
    if args.discoloss:
        return DiscoClipLoss(**ring_kwargs)

    return ClipLoss(**clip_kwargs)
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        **model_kwargs,
):
    """Create a model plus matching train and val preprocessing transforms."""
    # fold the explicit image_* overrides into a preprocess-cfg dict
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        **model_kwargs,
    )

    # build transforms from the preprocess cfg the model ended up with
    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
    preprocess_train = image_transform_v2(pp_cfg, is_train=True, aug_cfg=aug_cfg)
    preprocess_val = image_transform_v2(pp_cfg, is_train=False)

    return model, preprocess_train, preprocess_val
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        **model_kwargs,
):
    """Load a model that MUST have pretrained weights, for inference use.

    Unlike ``create_model_and_transforms`` this requires pretrained weights
    (``require_pretrained=True``) and builds only the inference (eval)
    transform.

    Returns:
        ``model`` alone when ``return_transform`` is False, otherwise the
        tuple ``(model, preprocess)``.
    """
    # Collect any explicit image_* overrides into the preprocess-config dict.
    force_preprocess_cfg = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        cache_dir=cache_dir,
        require_pretrained=True,  # pretrained weights are mandatory here
        **model_kwargs,
    )

    if not return_transform:
        return model

    # Inference-only transform derived from the model's preprocessing config.
    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )
    return model, preprocess
================================================
FILE: inf_clip/model_configs/EVA01-g-14-plus.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA01-g-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-B-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_base_patch16_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-E-14-plus.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-E-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-L-14-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"timm_model_name": "eva02_large_patch14_clip_336",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_large_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/LiT-B-16.json
================================================
{
"arch": "LiT-B-16",
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch16_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-base-uncased",
"hf_model_name": "bert-base-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/LiT-B-32.json
================================================
{
"arch": "LiT-B-32",
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch32_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-base-uncased",
"hf_model_name": "bert-base-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/LiT-L-16.json
================================================
{
"arch": "LiT-L-16",
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_large_patch16_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-large-uncased",
"hf_model_name": "bert-large-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-B.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_base_mci_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": false
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-S1.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci1",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-S2.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci2",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/RN101-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN101.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 384,
"layers": [
6,
8,
18,
8
],
"width": 96,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x4.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 288,
"layers": [
4,
6,
10,
6
],
"width": 80,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x64.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 448,
"layers": [
3,
15,
36,
10
],
"width": 128,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-256.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_base_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-384.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_base_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-512.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 512,
"timm_model_name": "vit_base_patch16_siglip_512",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_base_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 250000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch16_siglip_224",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-plus-240.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 240,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-plus.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-256.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 256,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-plus-256.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 256,
"layers": 12,
"width": 896,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-378-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 378,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA-336.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 336,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1024,
"heads": 16,
"layers": 24,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1024,
"heads": 16,
"layers": 24,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-16.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-280.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 280,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 768,
"heads": 12,
"layers": 12,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 768,
"heads": 12,
"layers": 12,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-quickgelu.json
================================================
{
"embed_dim": 768,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 320,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-256.json
================================================
{
"embed_dim": 1024,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_large_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1024,
"heads": 16,
"layers": 24,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-384.json
================================================
{
"embed_dim": 1024,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_large_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1024,
"heads": 16,
"layers": 24,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-16-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16,
"ls_init_value": 1e-4
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-32-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-16-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-16.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-32-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-32.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json
================================================
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_so400m_patch14_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP.json
================================================
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_so400m_patch14_siglip_224",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 16,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 336,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1280,
"heads": 20,
"layers": 32,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1280,
"heads": 20,
"layers": 32,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
}
}
================================================
FILE: inf_clip/model_configs/ViT-e-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 56,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.5715,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 36
}
}
================================================
FILE: inf_clip/model_configs/ViT-g-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViTamin-B-LTT.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_base_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-B.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vitamin_base_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L-256.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2-256.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2-336.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-S-LTT.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_small_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-S.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"timm_model_name": "vitamin_small_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-256.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-336.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-384.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_384",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 384
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"attn_pooler_heads": 8
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"attn_pooler_heads": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_base.json
================================================
{
"embed_dim": 512,
"multimodal_cfg": {
"width": 768,
"context_length": 76,
"vocab_size": 64000,
"mlp_ratio": 4,
"layers": 12,
"dim_head": 64,
"heads": 12,
"n_queries": 256,
"attn_pooler_heads": 8
},
"vision_cfg": {
"image_size": 288,
"layers": 12,
"width": 768,
"patch_size": 18,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 64000,
"layers": 12,
"heads": 12,
"width": 768,
"embed_cls": true,
"output_tokens": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"output_tokens": true
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"hf_proj_type": "linear",
"width": 768,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"width": 768,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/convnext_base.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_base_w.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_base_w_320.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_large.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_large_d.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: inf_clip/model_configs/convnext_large_d_320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: inf_clip/model_configs/convnext_small.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_small",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_tiny.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_tiny",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_xlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 20
}
}
================================================
FILE: inf_clip/model_configs/convnext_xxlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/convnext_xxlarge_320.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/mt5-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "google/mt5-base",
"hf_tokenizer_name": "google/mt5-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/mt5-xl-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "google/mt5-xl",
"hf_tokenizer_name": "google/mt5-xl",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-base-siglip.json
================================================
{
"embed_dim": 768,
"custom_text": true,
"init_logit_bias": -10,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_base_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-600M",
"hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-base.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-600M",
"hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-large-siglip.json
================================================
{
"embed_dim": 1152,
"custom_text": true,
"init_logit_bias": -10,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_so400m_patch14_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-1.3B",
"hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-large.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-1.3B",
"hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/swin_base_patch4_window7_224.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "swin_base_patch4_window7_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/vit_medium_patch16_gap_256.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_medium_patch16_gap_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "xlm-roberta-base",
"hf_tokenizer_name": "xlm-roberta-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "xlm-roberta-large",
"hf_tokenizer_name": "xlm-roberta-large",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/models/clip_arch.py
================================================
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial
from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
text_global_pool
from ..utils import to_2tuple
@dataclass
class CLIPVisionCfg:
    """Configuration for the CLIP image tower.

    Three mutually exclusive tower types are selected from these fields by
    ``_build_vision_tower``: a timm backbone (``timm_model_name`` set), a
    ModifiedResNet (``layers`` given as a 4-tuple of stage depths), or the
    native VisionTransformer (``layers`` as an int, the default).
    """
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # timm-specific options; a valid timm_model_name overrides layers/width/patch_size
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower.

    ``_build_text_tower`` builds a HuggingFace encoder when ``hf_model_name``
    is set, otherwise the native TextTransformer from the remaining fields.
    """
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    output_tokens: bool = False
    # Fix: these default to None, so the annotation must be Optional[dict]
    # (matches CLIPVisionCfg.act_kwargs / norm_kwargs).
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models
def get_cast_dtype(precision: str):
    """Return the torch dtype model weights are cast to for ``precision``.

    Only the mixed-precision modes cast weights: 'bf16' -> torch.bfloat16,
    'fp16' -> torch.float16. Any other string yields None (no casting).
    """
    return {'bf16': torch.bfloat16, 'fp16': torch.float16}.get(precision)
def get_input_dtype(precision: str):
    """Return the torch dtype inputs should be converted to for ``precision``.

    Unlike get_cast_dtype, the 'pure_*' modes also convert inputs:
    ('bf16' | 'pure_bf16') -> torch.bfloat16, ('fp16' | 'pure_fp16') ->
    torch.float16. Any other string yields None (inputs left as-is).
    """
    if precision in ('bf16', 'pure_bf16'):
        return torch.bfloat16
    if precision in ('fp16', 'pure_fp16'):
        return torch.float16
    return None
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Build the image encoder described by ``vision_cfg``.

    Dispatches on the config to one of three tower types:
      * ``TimmModel`` when ``timm_model_name`` is set,
      * ``ModifiedResNet`` when ``layers`` is a tuple/list of stage depths,
      * ``VisionTransformer`` otherwise.
    All towers project to ``embed_dim`` output features.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        return TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )

    if isinstance(vision_cfg.layers, (tuple, list)):
        # A per-stage layer count signals the ResNet variant; its attention
        # pool operates on a 32x-downsampled feature map, hence the width * 32.
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_cfg.width * 32 // vision_cfg.head_width,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )

    # fp16/bf16 cast models keep LayerNorm math in fp32 for stability
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if vision_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
    if vision_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **vision_cfg.act_kwargs)

    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vision_cfg.width // vision_cfg.head_width,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        attentional_pool=vision_cfg.attentional_pool,
        attn_pooler_queries=vision_cfg.attn_pooler_queries,
        attn_pooler_heads=vision_cfg.attn_pooler_heads,
        pos_embed_type=vision_cfg.pos_embed_type,
        no_ln_pre=vision_cfg.no_ln_pre,
        final_ln_after_pool=vision_cfg.final_ln_after_pool,
        pool_type=vision_cfg.pool_type,
        output_tokens=vision_cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build the text encoder described by ``text_cfg``.

    Returns an ``HFTextEncoder`` when ``hf_model_name`` is set, otherwise the
    native ``TextTransformer``. Both project to ``embed_dim`` output features.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        return HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj_type=text_cfg.hf_proj_type,
            pooler_type=text_cfg.hf_pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # fp16/bf16 cast models keep LayerNorm math in fp32 for stability
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    if text_cfg.norm_kwargs:
        norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
    if text_cfg.act_kwargs is not None:
        act_layer = partial(act_layer, **text_cfg.act_kwargs)

    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        mlp_ratio=text_cfg.mlp_ratio,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        no_causal_mask=text_cfg.no_causal_mask,
        pad_id=text_cfg.pad_id,
        pool_type=text_cfg.pool_type,
        proj_bias=text_cfg.proj_bias,
        output_tokens=text_cfg.output_tokens,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
class CLIP(nn.Module):
    """CLIP model with the original (OpenAI-style) flattened text tower.

    The text transformer's submodules are attached directly to this module
    (token_embedding, positional_embedding, ln_final, ...) so state dicts keep
    the original OpenAI key layout; compare CustomTextCLIP, which keeps the
    whole text tower under a single ``text`` submodule.
    """
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'clip'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),  # OpenAI CLIP temperature init
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict

        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Flatten the freshly built text tower into this module so checkpoint
        # keys match the original OpenAI layout.
        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.text_pool_type = text.pool_type
        # Non-persistent: the causal mask is rebuilt on construction, not loaded.
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            # Used by SigLIP-style models; None means no bias term in logits.
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers (memory vs. compute)."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, normalize: bool = False):
        """Encode images to embeddings; L2-normalize when ``normalize`` is True."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids to embeddings; L2-normalize when ``normalize`` is True."""
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # Pool token sequence to a single embedding per sample.
        x, _ = text_global_pool(x, text, self.text_pool_type)
        if self.text_projection is not None:
            # Projection may be a Linear module or a bare weight matrix (nn.Parameter).
            if isinstance(self.text_projection, nn.Linear):
                x = self.text_projection(x)
            else:
                x = x @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def get_logits(self, image, text):
        """Return (image->text, text->image) scaled similarity logit matrices."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever of image/text is provided (missing one yields None).

        Returns a dict when ``output_dict`` is set, otherwise a tuple of
        (image_features, text_features, logit_scale[, logit_bias]).
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps the text tower as a single ``text`` submodule.

    Unlike CLIP (flattened OpenAI layout), the text encoder here may be any
    tower built by ``_build_text_tower`` (e.g. a HuggingFace model), accessed
    uniformly through ``self.text``.
    """
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'clip'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),  # OpenAI CLIP temperature init
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            # Used by SigLIP-style models; None means no bias term in logits.
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the text tower except for the last ``unlocked_layers`` layers."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers (memory vs. compute)."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Encode images to embeddings; L2-normalize when ``normalize`` is True."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Encode token ids to embeddings; L2-normalize when ``normalize`` is True."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Return (image->text, text->image) scaled similarity logit matrices."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
    ):
        """Encode whichever of image/text is provided (missing one yields None).

        Returns a dict when ``output_dict`` is set, otherwise a tuple of
        (image_features, text_features, logit_scale[, logit_bias]).
        """
        image_features = self.encode_image(image, normalize=True) if image is not None else None
        text_features = self.encode_text(text, normalize=True) if text is not None else None

        if self.output_dict:
            out_dict = {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16).

    Walks the module tree and casts the weights of conv/linear layers,
    attention projection tensors, and the bare nn.Parameter projection
    matrices of the text/vision towers. Other parameters (e.g. LayerNorm)
    are left untouched.
    """

    def _convert_weights(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        if isinstance(module, (nn.MultiheadAttention, Attention)):
            # Some of these attributes are None depending on attention config.
            for name in (
                "in_proj_weight",
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ):
                tensor = getattr(module, name)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        if isinstance(module, (CLIP, TextTransformer)):
            # convert text nn.Parameter projections
            proj = getattr(module, "text_projection", None)
            if proj is not None:
                proj.data = proj.data.to(dtype)

        if isinstance(module, VisionTransformer):
            # convert vision nn.Parameter projections
            proj = getattr(module, "proj", None)
            if proj is not None:
                proj.data = proj.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
def convert_to_custom_text_state_dict(state_dict: dict):
    """Re-key an OpenAI-layout state dict for CustomTextCLIP.

    Used to maintain checkpoint compatibility: text-tower keys gain a
    'text.' prefix so they land under the ``text`` submodule. A state dict
    without 'text_projection' is assumed to already be in the new layout and
    is returned unchanged.
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # old format state_dict, move text tower -> .text
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    return {
        ('text.' + k if k.startswith(text_prefixes) else k): v
        for k, v in state_dict.items()
    }
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Instantiate a CLIP model from an original OpenAI checkpoint state dict.

    The architecture (ViT vs. modified-ResNet visual tower, plus all tower
    hyper-parameters) is inferred from tensor shapes and key names in the
    state dict, after which the weights are loaded. Returns the model in
    eval mode with weights converted via ``convert_weights_to_fp16``.
    """
    # ViT checkpoints carry a "visual.proj" parameter; ResNet ones do not
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # one in_proj_weight per transformer block -> layer count
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # positional embedding rows = grid*grid patches + 1 class token
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # modified ResNet: count blocks per stage (layer1..layer4)
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # image size recovered from the attnpool grid (x32; presumed overall stride — TODO confirm)
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    # head count inferred assuming 64-dim attention heads
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # drop non-parameter metadata keys before load_state_dict
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """Return a ``torch.jit`` traced version of a CLIP model.

    Traces ``forward``, ``encode_text`` and ``encode_image`` with dummy
    batches sized from the model's own ``visual.image_size`` and
    ``context_length``, then re-attaches ``visual.image_size`` to the traced
    module so downstream code can still read it.
    """
    model.eval()
    image_size = model.visual.image_size
    # dummy inputs: all-ones images, all-zeros (int) token ids
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Resize the checkpoint's visual position embedding to match `model`.

    Mutates `state_dict` in place. No-op when the checkpoint has no visual
    positional embedding, the model's visual tower has no grid, or the
    sequence lengths already match.
    """
    checkpoint_pos = state_dict.get('visual.positional_embedding', None)
    if checkpoint_pos is None or not hasattr(model.visual, 'grid_size'):
        return

    target_grid = to_2tuple(model.visual.grid_size)
    n_extra = 1  # FIXME detect different token configs (ie no class token, or more)
    target_len = target_grid[0] * target_grid[1] + n_extra
    if target_len == checkpoint_pos.shape[0]:
        return

    # split off the non-spatial (class token) rows before resizing the grid
    if n_extra:
        tok_part, grid_part = checkpoint_pos[:n_extra], checkpoint_pos[n_extra:]
    else:
        tok_part, grid_part = None, checkpoint_pos

    src_grid = to_2tuple(int(math.sqrt(len(grid_part))))
    logging.info('Resizing position embedding grid-size from %s to %s', src_grid, target_grid)

    # (seq, dim) -> (1, dim, H, W) so F.interpolate can resize spatially
    grid_part = grid_part.reshape(1, src_grid[0], src_grid[1], -1).permute(0, 3, 1, 2)
    grid_part = F.interpolate(
        grid_part,
        size=target_grid,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # back to (seq, dim)
    grid_part = grid_part.permute(0, 2, 3, 1).reshape(1, target_grid[0] * target_grid[1], -1)[0]

    if tok_part is not None:
        resized = torch.cat([tok_part, grid_part], dim=0)
    else:
        resized = grid_part
    state_dict['visual.positional_embedding'] = resized
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
    """Resize the checkpoint's text position embedding to the model's length.

    Mutates `state_dict` in place; no-op when the checkpoint has no
    positional embedding or the lengths already match. Width must be
    unchanged between checkpoint and model.
    """
    checkpoint_pos = state_dict.get('positional_embedding', None)
    if checkpoint_pos is None:
        return

    # FIXME add support for text cls_token
    target_pos = getattr(model, 'positional_embedding', None)
    if target_pos is None:
        target_pos = getattr(model.text, 'positional_embedding', None)

    src_len, src_width = checkpoint_pos.shape[0], checkpoint_pos.shape[1]
    dst_len, dst_width = target_pos.shape[0], target_pos.shape[1]
    assert src_width == dst_width, 'text pos_embed width changed!'
    if src_len == dst_len:
        return

    logging.info('Resizing text position embedding num_pos from %s to %s', src_len, dst_len)
    # (num_pos, width) -> (1, width, num_pos) for 1d interpolation over positions
    resized = checkpoint_pos.reshape(1, src_len, src_width).permute(0, 2, 1)
    resized = F.interpolate(
        resized,
        size=dst_len,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    state_dict['positional_embedding'] = resized.permute(0, 2, 1)[0]
def get_model_preprocess_cfg(model):
    """Return the image preprocessing config (size/mean/std) for a model.

    Prefers the packaged ``preprocess_cfg`` dict on the visual tower, falling
    back to the separate legacy attributes when it is absent or empty.
    """
    module = getattr(model, 'visual', model)
    preprocess_cfg = getattr(module, 'preprocess_cfg', {})
    if not preprocess_cfg:
        # use separate legacy attributes if preprocess_cfg dict not found
        # FIX: use a None default so a model without `image_size` doesn't raise
        # AttributeError; previously the `is not None` guard below was unreachable,
        # inconsistent with the mean/std handling.
        size = getattr(module, 'image_size', None)
        if size is not None:
            preprocess_cfg['size'] = size
        mean = getattr(module, 'image_mean', None)
        if mean is not None:
            preprocess_cfg['mean'] = mean
        std = getattr(module, 'image_std', None)
        if std is not None:
            preprocess_cfg['std'] = std
    return preprocess_cfg
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
    """Attach the preprocessing config to the model's visual tower (or the model itself)."""
    target = getattr(model, 'visual', model)
    # legacy per-attribute fields, kept for backward compatibility
    target.image_mean = preprocess_cfg['mean']
    target.image_std = preprocess_cfg['std']
    # new-style: store the whole config as an independent deep copy
    target.preprocess_cfg = copy.deepcopy(preprocess_cfg)
def get_model_tokenize_cfg(model):
    """Collect tokenizer-relevant settings (context_length, vocab_size) from a model.

    Reads from the model's ``text`` tower when present, else the model itself;
    missing attributes are simply omitted from the returned dict.
    """
    text_module = getattr(model, 'text', model)
    cfg = {}
    for key in ('context_length', 'vocab_size'):
        value = getattr(text_module, key, None)
        if value is not None:
            cfg[key] = value
    return cfg
================================================
FILE: inf_clip/models/coca_arch.py
================================================
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .clip_arch import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
# transformers is an optional dependency: generation utilities come from HF's
# GenerationMixin components when available.
try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StopStringCriteria,
        EosTokenCriteria,
        StoppingCriteriaList
    )

    # maps a generation_type string to its logits warper; beam search is
    # dispatched separately in CoCa.generate
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    # transformers missing: CoCa.generate asserts on _has_transformers at call time
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False
@dataclass
class MultimodalCfg(CLIPTextCfg):
    """Config for CoCa's multimodal text decoder; extends CLIPTextCfg.

    Consumed by `_build_text_decoder_tower` (context_length/width/heads/layers
    and ls_init_value are forwarded to MultimodalTransformer).
    """
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8
def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Build CoCa's multimodal (caption) decoder from a MultimodalCfg.

    `embed_dim` becomes the decoder's output dimension. A dict config is
    coerced to MultimodalCfg first.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # pick the fp32-casting LayerNorm variant when running in half precision
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm

    return MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
if not isinstance(token_id, torch.Tensor):
if isinstance(token_id, int):
token_id = [token_id]
token_id = torch.tensor(token_id, device=device)
return token_id
class CoCa(nn.Module):
    """CoCa (Contrastive Captioner): CLIP-style image/text towers plus a
    multimodal text decoder producing captioning logits.

    `forward` returns contrastive features together with decoder logits;
    `generate` performs autoregressive caption generation via beam search or
    top-k/top-p sampling (requires the optional `transformers` package).
    """
    arch_type: torch.jit.Final[str] = 'coca'

    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        # dict configs are coerced to their dataclasses
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # NOTE(review): both branches of this conditional are identical, so the
        # hf_model_name check is currently a no-op — confirm intended.
        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # the decoder's output_dim is the vocab size -> per-token logits
        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        # propagate gradient checkpointing to all three towers
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def _encode_image(self, images, normalize: bool = True):
        """Return (pooled image latent, per-token embeddings) from the vision tower."""
        image_latent, tokens_embs = self.visual(images)
        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
        return image_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Return (pooled text latent, per-token embeddings) from the text tower."""
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, images, normalize: bool = True):
        """Pooled image features only (contrastive head)."""
        image_latent, _ = self._encode_image(images, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize: bool = True):
        """Pooled text features only (contrastive head)."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def forward(
            self,
            image,
            text: Optional[torch.Tensor] = None,
            image_latent: Optional[torch.Tensor] = None,
            image_embs: Optional[torch.Tensor] = None,
            output_labels: bool = True,
    ):
        """Run both contrastive towers plus the caption decoder.

        Precomputed `image_latent`/`image_embs` may be passed in (as done by
        `generate`) to avoid re-encoding the image at every decoding step.
        With `output_labels=True`, token embeddings are shifted so decoder
        logits align with next-token labels for the captioning loss.
        """
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        if text is None:
            return {"image_features": image_latent, "image_embs": image_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        logits = self.text_decoder(image_embs, token_embs)
        out_dict = {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias
        return out_dict

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False  # if True output.shape == (batch_size, seq_len)
    ):
        """Autoregressively generate caption token ids for `image`.

        `generation_type` selects beam search or top-k/top-p sampling; runs
        entirely under torch.no_grad(). The sampling path temporarily switches
        to eval() and restores the previous training mode before returning.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = image.device

        with torch.no_grad():
            # 49406/49407: default SOT/EOT ids — presumably the OpenCLIP BPE
            # tokenizer's specials; pass explicit ids for custom tokenizers.
            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs=image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    # right-pad beam search output up to the requested length
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                            output,
                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # encode the image once; reuse its features on every decoding step
            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            while True:
                # only the trailing max_seq_len tokens fit the decoder context
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    image,
                    x,
                    image_latent=image_latent,
                    image_embs=image_embs,
                    output_labels=False,
                )["logits"][:, -1]
                # sequences that already emitted EOS (or pad) only receive padding
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        # force EOS on the final step
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        """Diverse (group) beam search decoding, adapted from HF GenerationMixin."""
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        # each beam works on its own copy of the image features
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce same tokens everytime.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:
            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                image_latent=image_latent,
                image_embs=image_embs,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []
                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # keep 2*group_size candidates so finished (EOS) beams can be replaced
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    """Assemble decoder inputs for one generation step (HF GenerationMixin style).

    With a `past` cache only the newest token is fed; position ids are derived
    from `attention_mask` (when given) for batched generation, with padded
    positions pinned to 1.
    """
    text_ids = input_ids[:, -1].unsqueeze(-1) if past else input_ids

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)
    if attention_mask is None or position_ids is not None:
        # note: an explicitly supplied position_ids is intentionally discarded
        position_ids = None
    else:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    return {
        "text": text_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
================================================
FILE: inf_clip/models/hf_configs.py
================================================
# HF architecture dict:
# Maps a HuggingFace config `model_type` to (a) `config_names`: the HF config
# attribute / submodule attribute names to read tower hyper-parameters from,
# and (b) `pooler`: the default pooler name looked up in hf_model._POOLERS.
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/bert
    "bert": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
        },
        "pooler": "cls_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/m2m_100
    "m2m_100": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "encoder_attention_heads",
            "layers": "encoder_layers",
        },
        "pooler": "cls_pooler",
    },
}
================================================
FILE: inf_clip/models/hf_model.py
================================================
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
from contextlib import nullcontext
import torch
import torch.nn as nn
from torch import TensorType
# transformers is optional; HFTextEncoder raises a RuntimeError at
# construction time when it is missing.
try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
    from transformers.modeling_utils import no_init_weights
except ImportError as e:
    transformers = None

    # minimal stand-ins so annotations below still resolve without transformers
    class BaseModelOutput:
        pass

    class PretrainedConfig:
        pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: ?last - for gpt-like models
_POOLERS = {}
def register_pooler(cls):
"""Decorator registering pooler class"""
_POOLERS[_camel2snake(cls.__name__)] = cls
return cls
@register_pooler
class MeanPooler(nn.Module):
    """Average hidden states over the valid (attended) token positions."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        mask = attention_mask.unsqueeze(-1)
        token_sum = (x.last_hidden_state * mask).sum(dim=1)
        return token_sum / attention_mask.sum(-1, keepdim=True)
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over valid (unmasked) token positions.

    Padded positions are set to -inf before the max so they can never be
    selected.
    """

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # FIX: fill the *padding* positions (attention_mask == 0) with -inf.
        # Previously the mask was applied un-inverted, filling the valid tokens
        # with -inf and max-pooling over padding instead.
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
        return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling, preferring the HF pooler_output when available."""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if self.use_pooler_output:
            pooled_types = (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)
            if isinstance(x, pooled_types) and (x.pooler_output is not None):
                return x.pooler_output
        # fall back to the raw CLS-position hidden state
        return x.last_hidden_state[:, self.cls_token_position, :]
@register_pooler
class ClsLastHiddenStatePooler(nn.Module):
    """CLS token pooling
    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False
    """

    def __init__(self):
        super().__init__()
        # index of the CLS token within the sequence
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        # attention_mask is unused; the CLS position is returned unconditionally
        return x.last_hidden_state[:, self.cls_token_position, :]
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter: wraps an HF transformer as a CLIP text tower.

    Encodes token ids into pooled, projected features; can also return the
    per-token hidden states when `output_tokens` is set.
    """
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            pooler_type: str = None,
            proj_type: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        """Load the HF backbone and set up pooler + output projection.

        pooler_type defaults to the architecture's entry in `arch_dict`;
        proj_type may be None (identity when widths match), 'linear' or 'mlp'.
        """
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        # FIXME: Gradient Accumulation can't fully resume the dropout state, so we close dropout here.
        # self.config.attention_probs_dropout_prob = 0.0  # Disable dropout
        # self.config.hidden_dropout_prob = 0.0  # Disable dropout
        # Enable sdpa attention
        self.config._attn_implementation = 'sdpa'
        # initialization of the model is really slow (https://github.com/huggingface/transformers/issues/9205#issuecomment-748741195)
        # FIXME: To speed up the initialization of the model, we only load pretrained weights and
        # disable the torch initialization of the weights.
        if pretrained:
            context = no_init_weights
        else:
            context = nullcontext
        with context():
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # encoder-decoder model: only the encoder is used as the text tower
                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True, add_pooling_layer=uses_transformer_pooler)

        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])

        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models
        self.vocab_size = getattr(self.config, 'vocab_size', 0)
        self.context_length = getattr(self.config, 'max_position_embeddings', 0)

        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj_type == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj_type == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        # NOTE(review): when d_model != output_dim and proj_type is None,
        # self.proj is never assigned and forward() will raise AttributeError
        # — confirm this combination is disallowed upstream.

    def forward(self, x: TensorType):
        """Encode token ids -> pooled, projected features (and optionally tokens)."""
        # non-pad positions attend; assumes config.pad_token_id marks padding — TODO confirm
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        # when a ClsPooler is used, drop the CLS position from the returned tokens
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :]
            if type(self.pooler) == ClsPooler
            else out.last_hidden_state
        )

        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the backbone, optionally leaving the top `unlocked_layers` trainable.

        LayerNorm parameters stay trainable when `freeze_layer_norm` is False.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # embeddings count as one extra "layer"; keep the last `unlocked_layers` trainable
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # HF models manage their own checkpointing toggle
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        # weights come from the pretrained checkpoint; nothing to initialize
        pass
================================================
FILE: inf_clip/models/lit_arch.py
================================================
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as torch_checkpoint
from .clip_arch import _build_vision_tower, _build_text_tower
@dataclass
class LiTVisionCfg:
    """Vision tower config for LiT models (consumed by clip_arch._build_vision_tower)."""
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    # timm-specific options: a valid timm model name overrides layers, width, patch_size
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
@dataclass
class LiTTextCfg:
    """Text tower config for LiT models (consumed by clip_arch._build_text_tower)."""
    context_length: int = 77
    vocab_size: int = 49408
    hf_tokenizer_name: Optional[str] = None
    tokenizer_kwargs: Optional[dict] = None

    width: int = 512
    heads: int = 8
    layers: int = 12
    mlp_ratio: float = 4.0
    ls_init_value: Optional[float] = None  # layer scale initial value
    embed_cls: bool = False
    pad_id: int = 0
    no_causal_mask: bool = False  # disable causal masking
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'argmax'
    proj_bias: bool = False
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None  # annotation fixed: default is None
    norm_kwargs: Optional[dict] = None  # annotation fixed: default is None

    # HuggingFace specific text tower config
    hf_model_name: Optional[str] = None
    hf_model_pretrained: bool = True
    hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # default pooler for HF models
class LiT(nn.Module):
    """LiT (Locked-image Tuning) two-tower model.

    Pairs a locked (frozen) image tower with a trainable text tower, following
    https://arxiv.org/abs/2111.07991. The image tower is locked in ``__init__``.
    """
    output_dict: torch.jit.Final[bool]
    arch_type: torch.jit.Final[str] = 'lit'

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: LiTVisionCfg,
            text_cfg: LiTTextCfg,
            quick_gelu: bool = False,
            init_logit_scale: float = np.log(1 / 0.07),
            init_logit_bias: Optional[float] = None,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        """Build both towers, the learnable logit scale/bias, and lock the image tower.

        Args:
            embed_dim: shared embedding dimension of both towers.
            vision_cfg: image tower configuration.
            text_cfg: text tower configuration.
            quick_gelu: use the QuickGELU activation (OpenAI CLIP compatibility).
            init_logit_scale: initial value for the (log-space) logit scale.
            init_logit_bias: if given, adds a learnable logit bias (SigLIP-style).
            cast_dtype: optional dtype to cast tower weights to.
            output_dict: if True, forward() returns a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.context_length = self.text.context_length
        self.vocab_size = self.text.vocab_size
        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
        else:
            self.logit_bias = None
        self.embed_dim = embed_dim
        # LiT: the image tower is always locked.
        self.lock_image_tower()

    def get_embed_dim(self):
        """Return the shared embedding dimension."""
        return self.embed_dim

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Encode images through the full visual tower (trunk + head)."""
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_trunk_image(self, image, normalize: bool = False):
        """Encode images, returning both trunk features and projected features."""
        trunk_features = self.visual.forward_trunk(image)
        features = self.visual.forward_head(trunk_features)
        return trunk_features, F.normalize(features, dim=-1) if normalize else features

    def project_image(self, trunk_features, normalize: bool = False):
        """Project precomputed trunk features through the visual head only.

        NOTE(review): uses ``self.visual.head`` directly while encode_trunk_image
        uses ``forward_head`` — presumably equivalent for TimmModel; confirm.
        """
        features = self.visual.head(trunk_features)
        return trunk_features, F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Encode text through the text tower."""
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(self, image, text):
        """Return (image->text, text->image) similarity logits for a batch pair."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        image_logits = self.logit_scale.exp() * image_features @ text_features.T
        if self.logit_bias is not None:
            image_logits += self.logit_bias
        text_logits = image_logits.T
        return image_logits, text_logits

    def forward(
            self,
            image: Optional[torch.Tensor] = None,
            text: Optional[torch.Tensor] = None,
            project_only: Optional[bool] = False,
    ):
        """Encode image and/or text; either input may be None.

        When ``project_only`` is True, ``image`` is expected to already hold
        trunk features and only the visual head projection is applied.
        """
        if image is not None:
            if project_only:
                image_trunk_features, image_features = self.project_image(image, normalize=True)
            else:
                image_trunk_features, image_features = self.encode_trunk_image(image, normalize=True)
        else:
            # BUG FIX: the original wrote `a, b = (...) if image is not None else None`,
            # which raised TypeError (cannot unpack None) whenever image was None.
            image_trunk_features, image_features = None, None
        text_features = self.encode_text(text, normalize=True) if text is not None else None
        if self.output_dict:
            out_dict = {
                "image_trunk_features": image_trunk_features,
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict
        if self.logit_bias is not None:
            return image_trunk_features, image_features, text_features, self.logit_scale.exp(), self.logit_bias
        return image_trunk_features, image_features, text_features, self.logit_scale.exp()
================================================
FILE: inf_clip/models/loss.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from inf_cl import cal_flash_loss, cal_ring_loss, cal_inf_loss
def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    """Gather image/text feature tensors from all workers, concatenated along dim 0.

    Args:
        image_features: this worker's image embeddings, (local_batch, dim).
        text_features: this worker's text embeddings, (local_batch, dim).
        local_loss: if True, the caller only needs gradients through its own
            local features, so the local slice is not re-spliced into the
            gathered (gradient-free) tensors.
        gather_with_grad: use a differentiable all_gather so gradients flow to
            every worker's features.
        rank: this worker's rank; used to splice the local (grad-carrying)
            tensors back into the gathered result.
        world_size: total number of workers; when 1 inputs are returned as-is.
        use_horovod: use horovod collectives instead of torch.distributed.

    Returns:
        Tuple (all_image_features, all_text_features), each of shape
        (world_size * local_batch, dim).
    """
    # Single-process run: nothing to gather.
    if world_size == 1:
        return image_features, text_features
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            # Gathered copies carry no autograd history.
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            # torch.distributed.nn.all_gather is autograd-aware.
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            # Plain all_gather: the received tensors are detached from autograd.
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)
    return all_image_features, all_text_features
def all_reduce(tensor):
    """Sum `tensor` across all distributed workers in place.

    No-op (returns the tensor untouched) when torch.distributed is unavailable
    or the default process group has not been initialized.

    Returns:
        The input tensor (reduced in place when distributed is active), so the
        call can be used inline.
    """
    # FIX: guard on is_initialized() too — dist.all_reduce raises if distributed
    # is available but init_process_group() was never called (e.g. single-GPU
    # runs). Also drops the unused `world_size` local.
    if not dist.is_available() or not dist.is_initialized():
        return tensor
    dist.all_reduce(tensor)
    return tensor
class ClipLoss(nn.Module):
    """Bidirectional CLIP contrastive loss with optional cross-worker gathering.

    The loss is averaged over the image->text and text->image directions and
    rescaled so the reported value is comparable across gathering modes.
    """

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        # Ground-truth label cache, keyed by device.
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        """Return diagonal target labels, reusing the cache when possible."""
        if self.prev_num_logits == num_logits and device in self.labels:
            return self.labels[device]
        labels = torch.arange(num_logits, device=device, dtype=torch.long)
        if self.world_size > 1 and self.local_loss:
            # Local rows match the rank-offset slice of the gathered columns.
            labels = labels + num_logits * self.rank
        if self.cache_labels:
            self.labels[device] = labels
            self.prev_num_logits = num_logits
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        """Compute scaled similarity logits, gathering features when distributed."""
        if self.world_size <= 1:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
            return logits_per_image, logits_per_text
        all_image_features, all_text_features = gather_features(
            image_features, text_features,
            self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
        if self.local_loss:
            # Local rows vs. gathered columns.
            logits_per_image = logit_scale * image_features @ all_text_features.T
            logits_per_text = logit_scale * text_features @ all_image_features.T
        else:
            # Full gathered matrix; text logits are just the transpose.
            logits_per_image = logit_scale * all_image_features @ all_text_features.T
            logits_per_text = logits_per_image.T
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale):
        """Return the training loss and an all-reduced display loss."""
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
        labels = self.get_ground_truth(device, logits_per_image.shape[0])
        # Per-sample average of both contrastive directions.
        per_sample_loss = (
            F.cross_entropy(logits_per_image, labels, reduction='none')
            + F.cross_entropy(logits_per_text, labels, reduction='none')
        ) / 2
        # Rescale so gathered (global-batch) losses compare with local ones.
        scale_factor = (per_sample_loss.shape[0] / image_features.shape[0])
        total_loss = torch.mean(per_sample_loss * scale_factor)
        show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)
        return {"contrastive_loss": total_loss, "show_loss": show_loss}
class DiscoClipLoss(nn.Module):
def __init__(
self,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.prev_num_logits = 0
self.labels = {}
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
# calculated ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
la
gitextract_u47kktxp/
├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── inf_cl/
│ ├── __init__.py
│ ├── flash.py
│ └── ring.py
├── inf_clip/
│ ├── __init__.py
│ ├── constants.py
│ ├── factory.py
│ ├── model_configs/
│ │ ├── EVA01-g-14-plus.json
│ │ ├── EVA01-g-14.json
│ │ ├── EVA02-B-16.json
│ │ ├── EVA02-E-14-plus.json
│ │ ├── EVA02-E-14.json
│ │ ├── EVA02-L-14-336.json
│ │ ├── EVA02-L-14.json
│ │ ├── LiT-B-16.json
│ │ ├── LiT-B-32.json
│ │ ├── LiT-L-16.json
│ │ ├── MobileCLIP-B.json
│ │ ├── MobileCLIP-S1.json
│ │ ├── MobileCLIP-S2.json
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── RN50x64.json
│ │ ├── ViT-B-16-SigLIP-256.json
│ │ ├── ViT-B-16-SigLIP-384.json
│ │ ├── ViT-B-16-SigLIP-512.json
│ │ ├── ViT-B-16-SigLIP-i18n-256.json
│ │ ├── ViT-B-16-SigLIP.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16-quickgelu.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-256.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14-378-quickgelu.json
│ │ ├── ViT-H-14-CLIPA-336.json
│ │ ├── ViT-H-14-CLIPA.json
│ │ ├── ViT-H-14-quickgelu.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14-CLIPA-336.json
│ │ ├── ViT-L-14-CLIPA.json
│ │ ├── ViT-L-14-quickgelu.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16-SigLIP-256.json
│ │ ├── ViT-L-16-SigLIP-384.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-M-16-alt.json
│ │ ├── ViT-M-16.json
│ │ ├── ViT-M-32-alt.json
│ │ ├── ViT-M-32.json
│ │ ├── ViT-S-16-alt.json
│ │ ├── ViT-S-16.json
│ │ ├── ViT-S-32-alt.json
│ │ ├── ViT-S-32.json
│ │ ├── ViT-SO400M-14-SigLIP-384.json
│ │ ├── ViT-SO400M-14-SigLIP.json
│ │ ├── ViT-bigG-14-CLIPA-336.json
│ │ ├── ViT-bigG-14-CLIPA.json
│ │ ├── ViT-bigG-14.json
│ │ ├── ViT-e-14.json
│ │ ├── ViT-g-14.json
│ │ ├── ViTamin-B-LTT.json
│ │ ├── ViTamin-B.json
│ │ ├── ViTamin-L-256.json
│ │ ├── ViTamin-L-336.json
│ │ ├── ViTamin-L.json
│ │ ├── ViTamin-L2-256.json
│ │ ├── ViTamin-L2-336.json
│ │ ├── ViTamin-L2.json
│ │ ├── ViTamin-S-LTT.json
│ │ ├── ViTamin-S.json
│ │ ├── ViTamin-XL-256.json
│ │ ├── ViTamin-XL-336.json
│ │ ├── ViTamin-XL-384.json
│ │ ├── coca_ViT-B-32.json
│ │ ├── coca_ViT-L-14.json
│ │ ├── coca_base.json
│ │ ├── coca_roberta-ViT-B-32.json
│ │ ├── convnext_base.json
│ │ ├── convnext_base_w.json
│ │ ├── convnext_base_w_320.json
│ │ ├── convnext_large.json
│ │ ├── convnext_large_d.json
│ │ ├── convnext_large_d_320.json
│ │ ├── convnext_small.json
│ │ ├── convnext_tiny.json
│ │ ├── convnext_xlarge.json
│ │ ├── convnext_xxlarge.json
│ │ ├── convnext_xxlarge_320.json
│ │ ├── mt5-base-ViT-B-32.json
│ │ ├── mt5-xl-ViT-H-14.json
│ │ ├── nllb-clip-base-siglip.json
│ │ ├── nllb-clip-base.json
│ │ ├── nllb-clip-large-siglip.json
│ │ ├── nllb-clip-large.json
│ │ ├── roberta-ViT-B-32.json
│ │ ├── swin_base_patch4_window7_224.json
│ │ ├── vit_medium_patch16_gap_256.json
│ │ ├── vit_relpos_medium_patch16_cls_224.json
│ │ ├── xlm-roberta-base-ViT-B-32.json
│ │ └── xlm-roberta-large-ViT-H-14.json
│ ├── models/
│ │ ├── clip_arch.py
│ │ ├── coca_arch.py
│ │ ├── hf_configs.py
│ │ ├── hf_model.py
│ │ ├── lit_arch.py
│ │ ├── loss.py
│ │ ├── modified_resnet.py
│ │ ├── pos_embed.py
│ │ ├── timm_model.py
│ │ ├── tokenizer.py
│ │ ├── transform.py
│ │ └── transformer.py
│ ├── openai.py
│ ├── pretrained.py
│ ├── train/
│ │ ├── data.py
│ │ ├── engine.py
│ │ ├── main.py
│ │ ├── optims.py
│ │ ├── params.py
│ │ └── utils.py
│ ├── utils.py
│ ├── zero_shot_classifier.py
│ └── zero_shot_metadata.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── benchmarks_eval.sh
│ ├── cc12m/
│ │ ├── clip_vit-b-32_bs32k.sh
│ │ ├── lit_vit-b-16_bs32k.sh
│ │ └── lit_vit-b-32_bs32k.sh
│ ├── cc3m/
│ │ ├── clip_r50_bs4k.sh
│ │ ├── clip_vit-b-32_bs16k.sh
│ │ └── lit_vit-b-32_bs16k.sh
│ ├── imagenet_eval.sh
│ └── laion400m/
│ ├── clip_vit-b-32_bs256k.sh
│ ├── lit_vit-b-16_bs256k.sh
│ ├── lit_vit-b-32_bs256k.sh
│ └── lit_vit-l-16_bs256k.sh
└── tests/
└── example.py
SYMBOL INDEX (447 symbols across 25 files)
FILE: inf_cl/flash.py
function _prob_fwd_kernel (line 12) | def _prob_fwd_kernel(
function _dq_prob_bwd_kernel (line 70) | def _dq_prob_bwd_kernel(
function _dk_prob_bwd_kernel (line 140) | def _dk_prob_bwd_kernel(
function _flash_prob_forward (line 204) | def _flash_prob_forward(q, k):
function _flash_prob_backward (line 242) | def _flash_prob_backward(q, k, lse, dlse):
class FlashProb (line 306) | class FlashProb(torch.autograd.Function):
method forward (line 309) | def forward(ctx, q, k):
method backward (line 316) | def backward(ctx, dlse):
function _cal_flash_loss (line 323) | def _cal_flash_loss(q, k, labels, head_dim=256):
function cal_flash_loss (line 337) | def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):
FILE: inf_cl/ring.py
class RingComm (line 17) | class RingComm:
method __init__ (line 19) | def __init__(self, process_group: dist.ProcessGroup):
method send_recv (line 33) | def send_recv(self, to_send, recv_tensor = None):
method commit (line 45) | def commit(self):
method wait (line 50) | def wait(self):
class GradientGather (line 59) | class GradientGather(torch.autograd.Function):
method forward (line 62) | def forward(ctx, x):
method backward (line 67) | def backward(ctx, dx):
class RingProb (line 72) | class RingProb(torch.autograd.Function):
method forward (line 75) | def forward(ctx, q, k, group):
method backward (line 109) | def backward(ctx, dlse):
class InfProb (line 154) | class InfProb(torch.autograd.Function):
method forward (line 157) | def forward(ctx, q, k, group):
method backward (line 190) | def backward(ctx, dlse):
function set_seed (line 231) | def set_seed(rank, seed=42):
function _cal_ring_loss (line 239) | def _cal_ring_loss(q, k, labels, head_dim=256):
function _cal_inf_loss (line 252) | def _cal_inf_loss(q, k, labels, head_dim=256):
function cal_ring_loss (line 265) | def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):
function cal_inf_loss (line 289) | def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):
FILE: inf_clip/factory.py
function _natural_key (line 32) | def _natural_key(string_):
function _rescan_model_configs (line 36) | def _rescan_model_configs():
function list_models (line 60) | def list_models():
function add_model_config (line 65) | def add_model_config(path):
function get_model_config (line 73) | def get_model_config(model_name):
function _get_hf_config (line 80) | def _get_hf_config(model_id, cache_dir=None):
function get_tokenizer (line 87) | def get_tokenizer(
function load_state_dict (line 131) | def load_state_dict(checkpoint_path: str, map_location='cpu'):
function load_checkpoint (line 146) | def load_checkpoint(
function create_model (line 183) | def create_model(
function create_loss (line 349) | def create_loss(args):
function create_model_and_transforms (line 410) | def create_model_and_transforms(
function create_model_from_pretrained (line 467) | def create_model_from_pretrained(
FILE: inf_clip/models/clip_arch.py
class CLIPVisionCfg (line 27) | class CLIPVisionCfg:
class CLIPTextCfg (line 58) | class CLIPTextCfg:
function get_cast_dtype (line 86) | def get_cast_dtype(precision: str):
function get_input_dtype (line 95) | def get_input_dtype(precision: str):
function _build_vision_tower (line 104) | def _build_vision_tower(
function _build_text_tower (line 173) | def _build_text_tower(
class CLIP (line 220) | class CLIP(nn.Module):
method __init__ (line 224) | def __init__(
method lock_image_tower (line 257) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
method set_grad_checkpointing (line 262) | def set_grad_checkpointing(self, enable=True):
method encode_image (line 266) | def encode_image(self, image, normalize: bool = False):
method encode_text (line 270) | def encode_text(self, text, normalize: bool = False):
method get_logits (line 287) | def get_logits(self, image, text):
method forward (line 296) | def forward(
class CustomTextCLIP (line 319) | class CustomTextCLIP(nn.Module):
method __init__ (line 323) | def __init__(
method lock_image_tower (line 346) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
method lock_text_tower (line 350) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
method set_grad_checkpointing (line 354) | def set_grad_checkpointing(self, enable=True):
method encode_image (line 358) | def encode_image(self, image, normalize: bool = False):
method encode_text (line 362) | def encode_text(self, text, normalize: bool = False):
method get_logits (line 366) | def get_logits(self, image, text):
method forward (line 375) | def forward(
function convert_weights_to_lp (line 398) | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
function convert_to_custom_text_state_dict (line 432) | def convert_to_custom_text_state_dict(state_dict: dict):
function build_model_from_openai_state_dict (line 450) | def build_model_from_openai_state_dict(
function trace_model (line 509) | def trace_model(model, batch_size=256, device=torch.device('cpu')):
function resize_pos_embed (line 525) | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', ...
function resize_text_pos_embed (line 559) | def resize_text_pos_embed(state_dict, model, interpolation: str = 'linea...
function get_model_preprocess_cfg (line 591) | def get_model_preprocess_cfg(model):
function set_model_preprocess_cfg (line 608) | def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
function get_model_tokenize_cfg (line 615) | def get_model_tokenize_cfg(model):
FILE: inf_clip/models/coca_arch.py
class MultimodalCfg (line 47) | class MultimodalCfg(CLIPTextCfg):
function _build_text_decoder_tower (line 55) | def _build_text_decoder_tower(
function _token_to_tensor (line 81) | def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
class CoCa (line 89) | class CoCa(nn.Module):
method __init__ (line 92) | def __init__(
method set_grad_checkpointing (line 146) | def set_grad_checkpointing(self, enable: bool = True):
method _encode_image (line 151) | def _encode_image(self, images, normalize: bool = True):
method _encode_text (line 156) | def _encode_text(self, text, normalize: bool = True):
method encode_image (line 161) | def encode_image(self, images, normalize: bool = True):
method encode_text (line 165) | def encode_text(self, text, normalize: bool = True):
method forward (line 169) | def forward(
method generate (line 204) | def generate(
method _generate_beamsearch (line 331) | def _generate_beamsearch(
function prepare_inputs_for_generation (line 481) | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **...
FILE: inf_clip/models/hf_model.py
class BaseModelOutput (line 22) | class BaseModelOutput:
class PretrainedConfig (line 26) | class PretrainedConfig:
function _camel2snake (line 33) | def _camel2snake(s):
function register_pooler (line 41) | def register_pooler(cls):
class MeanPooler (line 48) | class MeanPooler(nn.Module):
method forward (line 51) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
class MaxPooler (line 57) | class MaxPooler(nn.Module):
method forward (line 60) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
class ClsPooler (line 66) | class ClsPooler(nn.Module):
method __init__ (line 69) | def __init__(self, use_pooler_output=True):
method forward (line 74) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
class ClsLastHiddenStatePooler (line 85) | class ClsLastHiddenStatePooler(nn.Module):
method __init__ (line 90) | def __init__(self):
method forward (line 94) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
class HFTextEncoder (line 98) | class HFTextEncoder(nn.Module):
method __init__ (line 102) | def __init__(
method forward (line 166) | def forward(self, x: TensorType):
method lock (line 183) | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
method set_grad_checkpointing (line 201) | def set_grad_checkpointing(self, enable=True):
method init_parameters (line 204) | def init_parameters(self):
FILE: inf_clip/models/lit_arch.py
class LiTVisionCfg (line 14) | class LiTVisionCfg:
class LiTTextCfg (line 45) | class LiTTextCfg:
class LiT (line 73) | class LiT(nn.Module):
method __init__ (line 77) | def __init__(
method get_embed_dim (line 103) | def get_embed_dim(self):
method lock_image_tower (line 106) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
method lock_text_tower (line 110) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
method set_grad_checkpointing (line 114) | def set_grad_checkpointing(self, enable=True):
method encode_image (line 118) | def encode_image(self, image, normalize: bool = False):
method encode_trunk_image (line 122) | def encode_trunk_image(self, image, normalize: bool = False):
method project_image (line 127) | def project_image(self, trunk_features, normalize: bool = False):
method encode_text (line 131) | def encode_text(self, text, normalize: bool = False):
method get_logits (line 135) | def get_logits(self, image, text):
method forward (line 144) | def forward(
FILE: inf_clip/models/loss.py
function gather_features (line 21) | def gather_features(
function all_reduce (line 70) | def all_reduce(tensor):
class ClipLoss (line 79) | class ClipLoss(nn.Module):
method __init__ (line 81) | def __init__(
method get_ground_truth (line 102) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
method get_logits (line 115) | def get_logits(self, image_features, text_features, logit_scale):
method forward (line 133) | def forward(self, image_features, text_features, logit_scale):
class DiscoClipLoss (line 152) | class DiscoClipLoss(nn.Module):
method __init__ (line 154) | def __init__(
method get_ground_truth (line 169) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
method forward (line 180) | def forward(self, image_features, text_features, logit_scale):
class FlashClipLoss (line 204) | class FlashClipLoss(nn.Module):
method __init__ (line 206) | def __init__(
method get_ground_truth (line 221) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
method forward (line 231) | def forward(self, image_features, text_features, logit_scale):
class RingClipLoss (line 251) | class RingClipLoss(nn.Module):
method __init__ (line 253) | def __init__(
method forward (line 267) | def forward(self, image_features, text_features, logit_scale):
class InfClipLoss (line 291) | class InfClipLoss(nn.Module):
method __init__ (line 293) | def __init__(
method forward (line 307) | def forward(self, image_features, text_features, logit_scale):
class CoCaLoss (line 384) | class CoCaLoss(ClipLoss):
method __init__ (line 385) | def __init__(
method forward (line 410) | def forward(self, image_features, text_features, logits, labels, logit...
class DistillClipLoss (line 430) | class DistillClipLoss(ClipLoss):
method dist_loss (line 432) | def dist_loss(self, teacher_logits, student_logits):
method forward (line 435) | def forward(
function neighbour_exchange (line 469) | def neighbour_exchange(from_rank, to_rank, tensor, group=None):
function neighbour_exchange_bidir (line 489) | def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tens...
class NeighbourExchange (line 522) | class NeighbourExchange(torch.autograd.Function):
method forward (line 524) | def forward(ctx, from_rank, to_rank, group, tensor):
method backward (line 531) | def backward(ctx, grad_output):
function neighbour_exchange_with_grad (line 535) | def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
class NeighbourExchangeBidir (line 539) | class NeighbourExchangeBidir(torch.autograd.Function):
method forward (line 541) | def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_...
method backward (line 548) | def backward(ctx, *grad_outputs):
function neighbour_exchange_bidir_with_grad (line 553) | def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_...
class SigLipLoss (line 557) | class SigLipLoss(nn.Module):
method __init__ (line 567) | def __init__(
method get_ground_truth (line 587) | def get_ground_truth(self, device, dtype, num_logits, negative_only=Fa...
method get_logits (line 593) | def get_logits(self, image_features, text_features, logit_scale, logit...
method _loss (line 599) | def _loss(self, image_features, text_features, logit_scale, logit_bias...
method forward (line 610) | def forward(self, image_features, text_features, logit_scale, logit_bi...
FILE: inf_clip/models/modified_resnet.py
class Bottleneck (line 10) | class Bottleneck(nn.Module):
method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
method forward (line 42) | def forward(self, x: torch.Tensor):
class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module):
method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
method forward (line 68) | def forward(self, x):
class ModifiedResNet (line 95) | class ModifiedResNet(nn.Module):
method __init__ (line 103) | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
method _make_layer (line 132) | def _make_layer(self, planes, blocks, stride=1):
method init_parameters (line 141) | def init_parameters(self):
method lock (line 154) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
method set_grad_checkpointing (line 162) | def set_grad_checkpointing(self, enable=True):
method stem (line 166) | def stem(self, x):
method forward (line 173) | def forward(self, x):
FILE: inf_clip/models/pos_embed.py
function get_2d_sincos_pos_embed (line 20) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
function get_2d_sincos_pos_embed_from_grid (line 38) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
function interpolate_pos_embed (line 75) | def interpolate_pos_embed(model, checkpoint_model):
FILE: inf_clip/models/timm_model.py
class TimmModel (line 28) | class TimmModel(nn.Module):
method __init__ (line 32) | def __init__(
method lock (line 110) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
method set_grad_checkpointing (line 143) | def set_grad_checkpointing(self, enable=True):
method forward_trunk (line 149) | def forward_trunk(self, x):
method forward_head (line 152) | def forward_head(self, x):
method forward (line 155) | def forward(self, x):
FILE: inf_clip/models/tokenizer.py
function default_bpe (line 27) | def default_bpe():
function bytes_to_unicode (line 32) | def bytes_to_unicode():
function get_pairs (line 54) | def get_pairs(word):
function basic_clean (line 66) | def basic_clean(text):
function whitespace_clean (line 72) | def whitespace_clean(text):
function _clean_canonicalize (line 78) | def _clean_canonicalize(x):
function _clean_lower (line 83) | def _clean_lower(x):
function _clean_whitespace (line 88) | def _clean_whitespace(x):
function get_clean_fn (line 93) | def get_clean_fn(type: str):
function canonicalize_text (line 104) | def canonicalize_text(
class SimpleTokenizer (line 133) | class SimpleTokenizer(object):
method __init__ (line 134) | def __init__(
method bpe (line 172) | def bpe(self, token):
method encode (line 213) | def encode(self, text):
method decode (line 221) | def decode(self, tokens):
method __call__ (line 226) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...
function decode (line 271) | def decode(output_ids: torch.Tensor):
function tokenize (line 276) | def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT...
function random_mask_tokenize (line 280) | def random_mask_tokenize(
function simple_mask_tokenize (line 309) | def simple_mask_tokenize(
function syntax_mask_tokenize (line 331) | def syntax_mask_tokenize(
function get_reduction_mask_fn (line 390) | def get_reduction_mask_fn(type: str):
class HFTokenizer (line 403) | class HFTokenizer:
method __init__ (line 406) | def __init__(
method save_pretrained (line 426) | def save_pretrained(self, dest):
method __call__ (line 429) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...
method set_language (line 456) | def set_language(self, src_lang):
class SigLipTokenizer (line 463) | class SigLipTokenizer:
method __init__ (line 473) | def __init__(
method save_pretrained (line 497) | def save_pretrained(self, dest):
method __call__ (line 500) | def __call__(self, texts: Union[str, List[str]], context_length: Optio...
FILE: inf_clip/models/transform.py
class PreprocessCfg (line 17) | class PreprocessCfg:
method __post_init__ (line 26) | def __post_init__(self):
method num_channels (line 30) | def num_channels(self):
method input_size (line 34) | def input_size(self):
function merge_preprocess_dict (line 40) | def merge_preprocess_dict(
function merge_preprocess_kwargs (line 57) | def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
class AugmentationCfg (line 62) | class AugmentationCfg:
function _setup_size (line 75) | def _setup_size(size, error_msg):
class ResizeKeepRatio (line 88) | class ResizeKeepRatio:
method __init__ (line 94) | def __init__(
method get_params (line 116) | def get_params(
method __call__ (line 144) | def __call__(self, img):
method __repr__ (line 160) | def __repr__(self):
function center_crop_or_pad (line 167) | def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0...
class CenterCropOrPad (line 207) | class CenterCropOrPad(torch.nn.Module):
method __init__ (line 219) | def __init__(self, size, fill=0):
method forward (line 224) | def forward(self, img):
method __repr__ (line 234) | def __repr__(self) -> str:
function _convert_to_rgb (line 238) | def _convert_to_rgb(image):
class color_jitter (line 242) | class color_jitter(object):
method __init__ (line 246) | def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., ...
method __call__ (line 251) | def __call__(self, img):
class gray_scale (line 258) | class gray_scale(object):
method __init__ (line 262) | def __init__(self, p=0.2):
method __call__ (line 267) | def __call__(self, img):
function image_transform (line 274) | def image_transform(
function image_transform_v2 (line 393) | def image_transform_v2(
FILE: inf_clip/models/transformer.py
class LayerNormFp32 (line 15) | class LayerNormFp32(nn.LayerNorm):
method forward (line 18) | def forward(self, x: torch.Tensor):
class LayerNorm (line 24) | class LayerNorm(nn.LayerNorm):
method forward (line 27) | def forward(self, x: torch.Tensor):
class QuickGELU (line 33) | class QuickGELU(nn.Module):
method forward (line 35) | def forward(self, x: torch.Tensor):
class LayerScale (line 39) | class LayerScale(nn.Module):
method __init__ (line 40) | def __init__(self, dim, init_values=1e-5, inplace=False):
method forward (line 45) | def forward(self, x):
class PatchDropout (line 49) | class PatchDropout(nn.Module):
method __init__ (line 54) | def __init__(self, prob, exclude_first_token=True):
method forward (line 60) | def forward(self, x):
class Attention (line 89) | class Attention(nn.Module):
method __init__ (line 90) | def __init__(
method forward (line 132) | def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
class AttentionalPooler (line 187) | class AttentionalPooler(nn.Module):
method __init__ (line 188) | def __init__(
method forward (line 202) | def forward(self, x: torch.Tensor):
class ResidualAttentionBlock (line 210) | class ResidualAttentionBlock(nn.Module):
method __init__ (line 211) | def __init__(
method attention (line 239) | def attention(
method forward (line 254) | def forward(
class CustomResidualAttentionBlock (line 268) | class CustomResidualAttentionBlock(nn.Module):
method __init__ (line 269) | def __init__(
method get_reference_weight (line 306) | def get_reference_weight(self):
method forward (line 309) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
function _expand_token (line 315) | def _expand_token(token, batch_size: int):
class Transformer (line 319) | class Transformer(nn.Module):
method __init__ (line 320) | def __init__(
method get_cast_dtype (line 350) | def get_cast_dtype(self) -> torch.dtype:
method forward (line 355) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
class CustomTransformer (line 369) | class CustomTransformer(nn.Module):
method __init__ (line 371) | def __init__(
method get_cast_dtype (line 412) | def get_cast_dtype(self) -> torch.dtype:
method forward (line 418) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
class VisionTransformer (line 434) | class VisionTransformer(nn.Module):
method __init__ (line 437) | def __init__(
method lock (line 541) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
method init_parameters (line 574) | def init_parameters(self):
method set_grad_checkpointing (line 595) | def set_grad_checkpointing(self, enable=True):
method _global_pool (line 598) | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.T...
method forward (line 608) | def forward(self, x: torch.Tensor):
function text_global_pool (line 653) | def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: ...
class TextTransformer (line 668) | class TextTransformer(nn.Module):
method __init__ (line 671) | def __init__(
method init_parameters (line 731) | def init_parameters(self):
method set_grad_checkpointing (line 755) | def set_grad_checkpointing(self, enable=True):
method build_causal_mask (line 758) | def build_causal_mask(self):
method build_cls_mask (line 766) | def build_cls_mask(self, text, cast_dtype: torch.dtype):
method forward (line 775) | def forward(self, text):
class MultimodalTransformer (line 812) | class MultimodalTransformer(Transformer):
method __init__ (line 813) | def __init__(
method init_parameters (line 856) | def init_parameters(self):
method build_attention_mask (line 874) | def build_attention_mask(self):
method forward (line 882) | def forward(self, image_embs, text_embs):
method set_grad_checkpointing (line 907) | def set_grad_checkpointing(self, enable=True):
FILE: inf_clip/openai.py
function list_openai_models (line 20) | def list_openai_models() -> List[str]:
function load_openai_model (line 25) | def load_openai_model(
FILE: inf_clip/pretrained.py
function _pcfg (line 34) | def _pcfg(url='', hf_hub='', **kwargs):
function _slpcfg (line 47) | def _slpcfg(url='', hf_hub='', **kwargs):
function _apcfg (line 60) | def _apcfg(url='', hf_hub='', **kwargs):
function _mccfg (line 73) | def _mccfg(url='', hf_hub='', **kwargs):
function _clean_tag (line 519) | def _clean_tag(tag: str):
function list_pretrained (line 524) | def list_pretrained(as_str: bool = False):
function list_pretrained_models_by_tag (line 531) | def list_pretrained_models_by_tag(tag: str):
function list_pretrained_tags_by_model (line 541) | def list_pretrained_tags_by_model(model: str):
function is_pretrained_cfg (line 549) | def is_pretrained_cfg(model: str, tag: str):
function get_pretrained_cfg (line 555) | def get_pretrained_cfg(model: str, tag: str):
function get_pretrained_url (line 562) | def get_pretrained_url(model: str, tag: str):
function download_pretrained_from_url (line 567) | def download_pretrained_from_url(
function has_hf_hub (line 613) | def has_hf_hub(necessary=False):
function download_pretrained_from_hf (line 621) | def download_pretrained_from_hf(
function download_pretrained (line 632) | def download_pretrained(
function load_big_vision_weights (line 664) | def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
function convert_mobile_clip_state_dict (line 793) | def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fa...
function convert_state_dict (line 834) | def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
FILE: inf_clip/train/data.py
class CsvDataset (line 24) | class CsvDataset(Dataset):
method __init__ (line 25) | def __init__(self, input_filename, transforms, img_key, caption_key, s...
method __len__ (line 36) | def __len__(self):
method __getitem__ (line 39) | def __getitem__(self, idx):
class SharedEpoch (line 45) | class SharedEpoch:
method __init__ (line 46) | def __init__(self, epoch: int = 0):
method set_value (line 49) | def set_value(self, epoch):
method get_value (line 52) | def get_value(self):
class DataInfo (line 57) | class DataInfo:
method set_epoch (line 62) | def set_epoch(self, epoch):
function expand_urls (line 69) | def expand_urls(urls, weights=None):
function get_dataset_size (line 91) | def get_dataset_size(shards):
function get_imagenet (line 113) | def get_imagenet(args, preprocess_fns, split):
function count_samples (line 160) | def count_samples(dataloader):
function filter_no_caption_or_no_image (line 170) | def filter_no_caption_or_no_image(sample):
function log_and_continue (line 176) | def log_and_continue(exn):
function group_by_keys_nothrow (line 182) | def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes...
function tarfile_to_samples_nothrow (line 214) | def tarfile_to_samples_nothrow(src, handler=log_and_continue):
function pytorch_worker_seed (line 222) | def pytorch_worker_seed(increment=0):
function json_fetch (line 236) | def json_fetch(data, key='caption'):
class detshuffle2 (line 262) | class detshuffle2(wds.PipelineStage):
method __init__ (line 263) | def __init__(
method run (line 275) | def run(self, src):
class ResampledShards2 (line 294) | class ResampledShards2(IterableDataset):
method __init__ (line 297) | def __init__(
method __iter__ (line 324) | def __iter__(self):
function get_wds_dataset (line 348) | def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False...
function get_csv_dataset (line 468) | def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=No...
class SyntheticDataset (line 498) | class SyntheticDataset(Dataset):
method __init__ (line 500) | def __init__(
method __len__ (line 516) | def __len__(self):
method __getitem__ (line 519) | def __getitem__(self, idx):
function get_synthetic_dataset (line 525) | def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokeni...
function get_dataset_fn (line 548) | def get_dataset_fn(data_path, dataset_type):
function get_data (line 568) | def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
FILE: inf_clip/train/engine.py
function accuracy (line 28) | def accuracy(output, target, topk=(1,)):
function get_clip_metrics (line 34) | def get_clip_metrics(image_features, text_features, logit_scale):
function maybe_compute_generative_loss (line 54) | def maybe_compute_generative_loss(model_out):
function get_memory (line 61) | def get_memory():
function seconds_to_hms (line 70) | def seconds_to_hms(seconds):
function cal_grad_norm (line 77) | def cal_grad_norm(model):
function assign_learning_rate (line 87) | def assign_learning_rate(optimizer, new_lr):
function _warmup_lr (line 92) | def _warmup_lr(base_lr, warmup_length, step):
function const_lr (line 96) | def const_lr(optimizer, base_lr, warmup_length, steps):
function const_lr_cooldown (line 107) | def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown...
function cosine_lr (line 126) | def cosine_lr(optimizer, base_lr, warmup_length, steps):
function postprocess_clip_output (line 139) | def postprocess_clip_output(model_out):
function unwrap_model (line 147) | def unwrap_model(model):
function backward (line 154) | def backward(total_loss, scaler):
class AverageMeter (line 161) | class AverageMeter(object):
method __init__ (line 164) | def __init__(self):
method reset (line 167) | def reset(self):
method update (line 173) | def update(self, val, n=1):
class GradientAccum (line 180) | class GradientAccum:
method __init__ (line 182) | def __init__(self, model, loss, scaler, autocast, input_dtype, device):
method clear (line 203) | def clear(self):
method clear_state (line 208) | def clear_state(self):
method accum_inference (line 216) | def accum_inference(self, images, texts):
method accum_forward_backward (line 246) | def accum_forward_backward(self):
class GradientCache (line 292) | class GradientCache:
method __init__ (line 294) | def __init__(self, model, loss, scaler, autocast, input_dtype, device):
method clear (line 315) | def clear(self):
method clear_state (line 320) | def clear_state(self):
method forward_backward (line 327) | def forward_backward(self, images, texts):
method accum_inference (line 345) | def accum_inference(self, images, texts):
method accum_forward_backward (line 376) | def accum_forward_backward(self):
function train_one_epoch (line 432) | def train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer...
function evaluate (line 573) | def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
function zero_shot_run (line 682) | def zero_shot_run(model, classifier, dataloader, args):
function zero_shot_eval (line 709) | def zero_shot_eval(model, data, epoch, args, tokenizer=None):
FILE: inf_clip/train/main.py
function random_seed (line 42) | def random_seed(seed=42, rank=0):
function natural_key (line 50) | def natural_key(string_):
function copy_codebase (line 55) | def copy_codebase(args):
function prepare_logging (line 72) | def prepare_logging(args):
function get_latest_checkpoint (line 127) | def get_latest_checkpoint(path: str, remote : bool):
function prepare_resuming (line 143) | def prepare_resuming(args):
function prepare_remote_sync (line 178) | def prepare_remote_sync(args):
function prepare_model (line 205) | def prepare_model(args, device):
function prepare_optimizer_scaler (line 308) | def prepare_optimizer_scaler(args, model):
function prepare_scheduler (line 360) | def prepare_scheduler(args, optimizer, num_batches):
function main (line 383) | def main(args):
FILE: inf_clip/train/optims.py
class ScalingViTAdafactor (line 8) | class ScalingViTAdafactor(Optimizer):
method __init__ (line 18) | def __init__(
method _get_lr (line 52) | def _get_lr(param_group, param_state):
method _get_options (line 63) | def _get_options(param_group, param_shape):
method _rms (line 69) | def _rms(tensor):
method _approx_sq_grad (line 73) | def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
method step (line 81) | def step(self, closure=None):
class Lion (line 179) | class Lion(Optimizer):
method __init__ (line 184) | def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
method step (line 206) | def step(self, closure=None):
FILE: inf_clip/train/params.py
function get_default_params (line 9) | def get_default_params(model_name):
class ParseKwargs (line 18) | class ParseKwargs(argparse.Action):
method __call__ (line 19) | def __call__(self, parser, namespace, values, option_string=None):
function parse_args (line 30) | def parse_args(args):
function create_deepspeed_config (line 507) | def create_deepspeed_config(args):
FILE: inf_clip/train/utils.py
function setup_logging (line 19) | def setup_logging(log_file, level, include_host=False):
function remote_sync_s3 (line 43) | def remote_sync_s3(local_dir, remote_dir):
function remote_sync_fsspec (line 54) | def remote_sync_fsspec(local_dir, remote_dir):
function remote_sync (line 79) | def remote_sync(local_dir, remote_dir, protocol):
function keep_running_remote_sync (line 90) | def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
function start_sync_process (line 96) | def start_sync_process(sync_every, local_dir, remote_dir, protocol):
function pt_save (line 102) | def pt_save(pt_obj, file_path):
function pt_load (line 108) | def pt_load(file_path, map_location=None):
function check_exists (line 117) | def check_exists(file_path):
function get_autocast (line 126) | def get_autocast(precision):
function is_global_master (line 140) | def is_global_master(args):
function is_local_master (line 144) | def is_local_master(args):
function is_master (line 148) | def is_master(args, local=False):
function is_using_horovod (line 152) | def is_using_horovod():
function is_using_distributed (line 163) | def is_using_distributed():
function world_info_from_env (line 171) | def world_info_from_env():
function init_distributed_device (line 191) | def init_distributed_device(args):
function broadcast_object (line 254) | def broadcast_object(args, obj, src=0):
function all_gather_object (line 267) | def all_gather_object(args, obj, dst=0):
FILE: inf_clip/utils.py
function freeze_batch_norm_2d (line 9) | def freeze_batch_norm_2d(module, module_match={}, name=''):
function _ntuple (line 49) | def _ntuple(n):
function replace_linear (line 65) | def replace_linear(model, linear_replacement, include_modules=['c_fc', '...
function convert_int8_model_to_inference_mode (line 84) | def convert_int8_model_to_inference_mode(model):
FILE: inf_clip/zero_shot_classifier.py
function batched (line 9) | def batched(iterable, n):
function build_zero_shot_classifier (line 21) | def build_zero_shot_classifier(
function build_zero_shot_classifier_legacy (line 71) | def build_zero_shot_classifier_legacy(
FILE: tests/example.py
function create_cl_tensors (line 9) | def create_cl_tensors(rank, world_size):
Condensed preview — 152 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (516K chars).
[
{
"path": ".gitattributes",
"chars": 61,
"preview": "*.py linguist-language=python\n*.ipynb linguist-documentation\n"
},
{
"path": ".gitignore",
"chars": 2033,
"preview": "**/logs/\n**/wandb/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Di"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 10782,
"preview": "<p align=\"center\">\n <img src=\"https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad\" width="
},
{
"path": "inf_cl/__init__.py",
"chars": 80,
"preview": "from .flash import cal_flash_loss\nfrom .ring import cal_ring_loss, cal_inf_loss"
},
{
"path": "inf_cl/flash.py",
"chars": 12883,
"preview": "import math\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nimport triton\nimport triton.language as tl"
},
{
"path": "inf_cl/ring.py",
"chars": 12506,
"preview": "import os\nimport math\nimport random\n\nimport torch\nimport torch.distributed as dist\nimport torch.distributed.nn as dist_n"
},
{
"path": "inf_clip/__init__.py",
"chars": 1269,
"preview": "from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\n\nfrom .factory import create_model, create_model_and_tran"
},
{
"path": "inf_clip/constants.py",
"chars": 256,
"preview": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\nIMAG"
},
{
"path": "inf_clip/factory.py",
"chars": 19619,
"preview": "import json\nimport logging\nimport os\nimport re\nfrom copy import deepcopy\nfrom dataclasses import asdict\nfrom pathlib imp"
},
{
"path": "inf_clip/model_configs/EVA01-g-14-plus.json",
"chars": 401,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva_giant_patch14_22"
},
{
"path": "inf_clip/model_configs/EVA01-g-14.json",
"chars": 400,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva_giant_patch14_22"
},
{
"path": "inf_clip/model_configs/EVA02-B-16.json",
"chars": 404,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva02_base_patch16_cl"
},
{
"path": "inf_clip/model_configs/EVA02-E-14-plus.json",
"chars": 411,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva02_enormous_patch"
},
{
"path": "inf_clip/model_configs/EVA02-E-14.json",
"chars": 411,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva02_enormous_patch"
},
{
"path": "inf_clip/model_configs/EVA02-L-14-336.json",
"chars": 406,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 336,\n \"timm_model_name\": \"eva02_large_patch14_c"
},
{
"path": "inf_clip/model_configs/EVA02-L-14.json",
"chars": 406,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\": \"eva02_large_patch14_c"
},
{
"path": "inf_clip/model_configs/LiT-B-16.json",
"chars": 483,
"preview": "{\n \"arch\": \"LiT-B-16\",\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\""
},
{
"path": "inf_clip/model_configs/LiT-B-32.json",
"chars": 483,
"preview": "{\n \"arch\": \"LiT-B-32\",\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name\""
},
{
"path": "inf_clip/model_configs/LiT-L-16.json",
"chars": 487,
"preview": "{\n \"arch\": \"LiT-L-16\",\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"timm_model_name"
},
{
"path": "inf_clip/model_configs/MobileCLIP-B.json",
"chars": 483,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"vit_base_mci_224\",\n \"timm_model_pretraine"
},
{
"path": "inf_clip/model_configs/MobileCLIP-S1.json",
"chars": 476,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"fastvit_mci1\",\n \"timm_model_pretrained\": "
},
{
"path": "inf_clip/model_configs/MobileCLIP-S2.json",
"chars": 476,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"fastvit_mci2\",\n \"timm_model_pretrained\": "
},
{
"path": "inf_clip/model_configs/RN101-quickgelu.json",
"chars": 388,
"preview": "{\n \"embed_dim\": 512,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": [\n "
},
{
"path": "inf_clip/model_configs/RN101.json",
"chars": 364,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": [\n 3,\n 4,"
},
{
"path": "inf_clip/model_configs/RN50-quickgelu.json",
"chars": 389,
"preview": "{\n \"embed_dim\": 1024,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": [\n "
},
{
"path": "inf_clip/model_configs/RN50.json",
"chars": 364,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": [\n 3,\n 4"
},
{
"path": "inf_clip/model_configs/RN50x16.json",
"chars": 365,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 384,\n \"layers\": [\n 6,\n 8,"
},
{
"path": "inf_clip/model_configs/RN50x4.json",
"chars": 365,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"image_size\": 288,\n \"layers\": [\n 4,\n 6,"
},
{
"path": "inf_clip/model_configs/RN50x64.json",
"chars": 370,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 448,\n \"layers\": [\n 3,\n 1"
},
{
"path": "inf_clip/model_configs/ViT-B-16-SigLIP-256.json",
"chars": 710,
"preview": "{\n \"embed_dim\": 768,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": 2"
},
{
"path": "inf_clip/model_configs/ViT-B-16-SigLIP-384.json",
"chars": 710,
"preview": "{\n \"embed_dim\": 768,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": 3"
},
{
"path": "inf_clip/model_configs/ViT-B-16-SigLIP-512.json",
"chars": 710,
"preview": "{\n \"embed_dim\": 768,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": 5"
},
{
"path": "inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json",
"chars": 720,
"preview": "{\n \"embed_dim\": 768,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": 2"
},
{
"path": "inf_clip/model_configs/ViT-B-16-SigLIP.json",
"chars": 710,
"preview": "{\n \"embed_dim\": 768,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": 2"
},
{
"path": "inf_clip/model_configs/ViT-B-16-plus-240.json",
"chars": 295,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"image_size\": 240,\n \"layers\": 12,\n \"width\": 896,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-16-plus.json",
"chars": 295,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 896,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-16-quickgelu.json",
"chars": 318,
"preview": "{\n \"embed_dim\": 512,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-16.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-32-256.json",
"chars": 295,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 256,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-32-plus-256.json",
"chars": 295,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"image_size\": 256,\n \"layers\": 12,\n \"width\": 896,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-32-quickgelu.json",
"chars": 318,
"preview": "{\n \"embed_dim\": 512,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n "
},
{
"path": "inf_clip/model_configs/ViT-B-32.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-14-378-quickgelu.json",
"chars": 348,
"preview": "{\n \"embed_dim\": 1024,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 378,\n \"layers\": 32,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-14-CLIPA-336.json",
"chars": 604,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 336,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-14-CLIPA.json",
"chars": 604,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-14-quickgelu.json",
"chars": 348,
"preview": "{\n \"embed_dim\": 1024,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-14.json",
"chars": 324,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/ViT-H-16.json",
"chars": 324,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14-280.json",
"chars": 296,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 280,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14-336.json",
"chars": 296,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 336,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14-CLIPA-336.json",
"chars": 576,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 336,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14-CLIPA.json",
"chars": 576,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14-quickgelu.json",
"chars": 320,
"preview": "{\n \"embed_dim\": 768,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 24,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-14.json",
"chars": 296,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-16-320.json",
"chars": 296,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 320,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-L-16-SigLIP-256.json",
"chars": 713,
"preview": "{\n \"embed_dim\": 1024,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": "
},
{
"path": "inf_clip/model_configs/ViT-L-16-SigLIP-384.json",
"chars": 713,
"preview": "{\n \"embed_dim\": 1024,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": "
},
{
"path": "inf_clip/model_configs/ViT-L-16.json",
"chars": 296,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/ViT-M-16-alt.json",
"chars": 325,
"preview": "{\n \"embed_dim\": 384,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 512,\n "
},
{
"path": "inf_clip/model_configs/ViT-M-16.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 512,\n "
},
{
"path": "inf_clip/model_configs/ViT-M-32-alt.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 384,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 512,\n "
},
{
"path": "inf_clip/model_configs/ViT-M-32.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 512,\n "
},
{
"path": "inf_clip/model_configs/ViT-S-16-alt.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 256,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 384,\n "
},
{
"path": "inf_clip/model_configs/ViT-S-16.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 384,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 384,\n "
},
{
"path": "inf_clip/model_configs/ViT-S-32-alt.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 256,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 384,\n "
},
{
"path": "inf_clip/model_configs/ViT-S-32.json",
"chars": 294,
"preview": "{\n \"embed_dim\": 384,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 384,\n "
},
{
"path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json",
"chars": 743,
"preview": "{\n \"embed_dim\": 1152,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": "
},
{
"path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP.json",
"chars": 743,
"preview": "{\n \"embed_dim\": 1152,\n \"init_logit_bias\": -10,\n \"custom_text\": true,\n \"vision_cfg\": {\n \"image_size\": "
},
{
"path": "inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json",
"chars": 634,
"preview": "{\n \"embed_dim\": 1280,\n \"vision_cfg\": {\n \"image_size\": 336,\n \"layers\": 48,\n \"width\": 1664,\n "
},
{
"path": "inf_clip/model_configs/ViT-bigG-14-CLIPA.json",
"chars": 634,
"preview": "{\n \"embed_dim\": 1280,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 48,\n \"width\": 1664,\n "
},
{
"path": "inf_clip/model_configs/ViT-bigG-14.json",
"chars": 354,
"preview": "{\n \"embed_dim\": 1280,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 48,\n \"width\": 1664,\n "
},
{
"path": "inf_clip/model_configs/ViT-e-14.json",
"chars": 354,
"preview": "{\n \"embed_dim\": 1280,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 56,\n \"width\": 1792,\n "
},
{
"path": "inf_clip/model_configs/ViT-g-14.json",
"chars": 353,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 40,\n \"width\": 1408,\n "
},
{
"path": "inf_clip/model_configs/ViTamin-B-LTT.json",
"chars": 426,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_base_224\",\n \"timm_model_pretrained\": "
},
{
"path": "inf_clip/model_configs/ViTamin-B.json",
"chars": 425,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_base_224\",\n \"timm_model_pretrained\": "
},
{
"path": "inf_clip/model_configs/ViTamin-L-256.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large_256\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/ViTamin-L-336.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large_336\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/ViTamin-L.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large_224\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/ViTamin-L2-256.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large2_256\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/ViTamin-L2-336.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large2_336\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/ViTamin-L2.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_large2_224\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/ViTamin-S-LTT.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_small_224\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/ViTamin-S.json",
"chars": 426,
"preview": "{\n \"embed_dim\": 384,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_small_224\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/ViTamin-XL-256.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1152,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_xlarge_256\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/ViTamin-XL-336.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1152,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_xlarge_336\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/ViTamin-XL-384.json",
"chars": 430,
"preview": "{\n \"embed_dim\": 1152,\n \"vision_cfg\": {\n \"timm_model_name\": \"vitamin_xlarge_384\",\n \"timm_model_pretrained"
},
{
"path": "inf_clip/model_configs/coca_ViT-B-32.json",
"chars": 659,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/coca_ViT-L-14.json",
"chars": 664,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 24,\n \"width\": 1024,\n "
},
{
"path": "inf_clip/model_configs/coca_base.json",
"chars": 669,
"preview": "{\n \"embed_dim\": 512,\n \"multimodal_cfg\": {\n \"width\": 768,\n \"context_length\": 76,\n \"vocab_size\""
},
{
"path": "inf_clip/model_configs/coca_roberta-ViT-B-32.json",
"chars": 525,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/convnext_base.json",
"chars": 421,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_base\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/convnext_base_w.json",
"chars": 422,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_base\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/convnext_base_w_320.json",
"chars": 422,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_base\",\n \"timm_model_pretrained\":"
},
{
"path": "inf_clip/model_configs/convnext_large.json",
"chars": 423,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_large\",\n \"timm_model_pretrained\""
},
{
"path": "inf_clip/model_configs/convnext_large_d.json",
"chars": 420,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_large\",\n \"timm_model_pretrained\""
},
{
"path": "inf_clip/model_configs/convnext_large_d_320.json",
"chars": 420,
"preview": "{\n \"embed_dim\": 768,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_large\",\n \"timm_model_pretrained\""
},
{
"path": "inf_clip/model_configs/convnext_small.json",
"chars": 422,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_small\",\n \"timm_model_pretrained\""
},
{
"path": "inf_clip/model_configs/convnext_tiny.json",
"chars": 422,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_tiny\",\n \"timm_model_pretrained\""
},
{
"path": "inf_clip/model_configs/convnext_xlarge.json",
"chars": 426,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_xlarge\",\n \"timm_model_pretraine"
},
{
"path": "inf_clip/model_configs/convnext_xxlarge.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_xxlarge\",\n \"timm_model_pretrain"
},
{
"path": "inf_clip/model_configs/convnext_xxlarge_320.json",
"chars": 427,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"timm_model_name\": \"convnext_xxlarge\",\n \"timm_model_pretrain"
},
{
"path": "inf_clip/model_configs/mt5-base-ViT-B-32.json",
"chars": 305,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/mt5-xl-ViT-H-14.json",
"chars": 329,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/nllb-clip-base-siglip.json",
"chars": 509,
"preview": "{\n \"embed_dim\": 768,\n \"custom_text\": true,\n \"init_logit_bias\": -10,\n \"vision_cfg\": {\n \"image_size\": 3"
},
{
"path": "inf_clip/model_configs/nllb-clip-base.json",
"chars": 371,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/nllb-clip-large-siglip.json",
"chars": 512,
"preview": "{\n \"embed_dim\": 1152,\n \"custom_text\": true,\n \"init_logit_bias\": -10,\n \"vision_cfg\": {\n \"image_size\": "
},
{
"path": "inf_clip/model_configs/nllb-clip-large.json",
"chars": 399,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/model_configs/roberta-ViT-B-32.json",
"chars": 323,
"preview": "{\n \"embed_dim\": 512,\n \"quick_gelu\": true,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n "
},
{
"path": "inf_clip/model_configs/swin_base_patch4_window7_224.json",
"chars": 380,
"preview": "{\n \"embed_dim\": 640,\n \"vision_cfg\": {\n \"timm_model_name\": \"swin_base_patch4_window7_224\",\n \"timm_mod"
},
{
"path": "inf_clip/model_configs/vit_medium_patch16_gap_256.json",
"chars": 377,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"vit_medium_patch16_gap_256\",\n \"timm_model"
},
{
"path": "inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json",
"chars": 384,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"timm_model_name\": \"vit_relpos_medium_patch16_cls_224\",\n \"tim"
},
{
"path": "inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json",
"chars": 307,
"preview": "{\n \"embed_dim\": 512,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 12,\n \"width\": 768,\n "
},
{
"path": "inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json",
"chars": 337,
"preview": "{\n \"embed_dim\": 1024,\n \"vision_cfg\": {\n \"image_size\": 224,\n \"layers\": 32,\n \"width\": 1280,\n "
},
{
"path": "inf_clip/models/clip_arch.py",
"chars": 24432,
"preview": "\"\"\" CLIP Model\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimpo"
},
{
"path": "inf_clip/models/coca_arch.py",
"chars": 18994,
"preview": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nf"
},
{
"path": "inf_clip/models/hf_configs.py",
"chars": 2422,
"preview": "# HF architecture dict:\narch_dict = {\n # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n \"robe"
},
{
"path": "inf_clip/models/hf_model.py",
"chars": 7597,
"preview": "\"\"\" huggingface model adapter\n\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for u"
},
{
"path": "inf_clip/models/lit_arch.py",
"chars": 7051,
"preview": "from dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\n"
},
{
"path": "inf_clip/models/loss.py",
"chars": 23890,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\ntry:\n import torch.distributed.nn\n from t"
},
{
"path": "inf_clip/models/modified_resnet.py",
"chars": 7027,
"preview": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom ..util"
},
{
"path": "inf_clip/models/pos_embed.py",
"chars": 4044,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "inf_clip/models/timm_model.py",
"chars": 6090,
"preview": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower "
},
{
"path": "inf_clip/models/tokenizer.py",
"chars": 18171,
"preview": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\ni"
},
{
"path": "inf_clip/models/transform.py",
"chars": 14341,
"preview": "import numbers\nimport random\nimport warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, Lis"
},
{
"path": "inf_clip/models/transformer.py",
"chars": 34106,
"preview": "from collections import OrderedDict\nimport math\nfrom typing import Callable, List, Optional, Sequence, Tuple, Union\nfrom"
},
{
"path": "inf_clip/openai.py",
"chars": 3314,
"preview": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c"
},
{
"path": "inf_clip/pretrained.py",
"chars": 35037,
"preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nimp"
},
{
"path": "inf_clip/train/data.py",
"chars": 21240,
"preview": "import ast\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport braceexpand\nfrom dataclasses"
},
{
"path": "inf_clip/train/engine.py",
"chars": 27949,
"preview": "import json\nimport logging\nimport math\nimport os\nimport time\nfrom contextlib import nullcontext\n\nimport numpy as np\nimpo"
},
{
"path": "inf_clip/train/main.py",
"chars": 22222,
"preview": "import glob\nimport logging\nimport os\nimport re\nimport subprocess\nimport sys\nimport random\nimport time\nfrom functools imp"
},
{
"path": "inf_clip/train/optims.py",
"chars": 9273,
"preview": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.optim import Optimizer\n\n\nclass ScalingViTAdafactor(Optimizer):"
},
{
"path": "inf_clip/train/params.py",
"chars": 21537,
"preview": "import os\nimport argparse\nimport ast\nimport json\n\nfrom .utils import world_info_from_env\n\n\ndef get_default_params(model_"
},
{
"path": "inf_clip/train/utils.py",
"chars": 8893,
"preview": "import os\nimport time\nimport logging\nimport subprocess\nimport multiprocessing\n\nimport torch\nimport torch.distributed as "
},
{
"path": "inf_clip/utils.py",
"chars": 3472,
"preview": "from itertools import repeat\nimport collections.abc\n\nimport torch\nfrom torch import nn as nn\nfrom torchvision.ops.misc i"
},
{
"path": "inf_clip/zero_shot_classifier.py",
"chars": 4340,
"preview": "from functools import partial\nfrom itertools import islice\nfrom typing import Callable, List, Optional, Sequence, Union\n"
},
{
"path": "inf_clip/zero_shot_metadata.py",
"chars": 19245,
"preview": "\nOPENAI_IMAGENET_TEMPLATES = (\n lambda c: f'a bad photo of a {c}.',\n lambda c: f'a photo of many {c}.',\n lambda"
},
{
"path": "pyproject.toml",
"chars": 1394,
"preview": "[build-system]\nrequires = [\"pdm-backend\"]\nbuild-backend = \"pdm.backend\"\n\n[project]\nname = \"inf-cl\"\nversion = \"1.2\"\nautho"
},
{
"path": "requirements.txt",
"chars": 446,
"preview": "--extra-index-url https://download.pytorch.org/whl/cu118\n# basic dependencies\ntorch==2.2.0\ntorchvision==0.17.0\nnumpy==1."
},
{
"path": "scripts/benchmarks_eval.sh",
"chars": 293,
"preview": "clip_benchmark eval \\\n --model LiT-B-16 \\\n --pretrained work_dirs/epoch_8.pt \\\n --dataset datasets/imagenet.txt"
},
{
"path": "scripts/cc12m/clip_vit-b-32_bs32k.sh",
"chars": 1862,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/cc12m/lit_vit-b-16_bs32k.sh",
"chars": 1931,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/cc12m/lit_vit-b-32_bs32k.sh",
"chars": 1931,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/cc3m/clip_r50_bs4k.sh",
"chars": 1874,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/cc3m/clip_vit-b-32_bs16k.sh",
"chars": 1890,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/cc3m/lit_vit-b-32_bs16k.sh",
"chars": 1927,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/imagenet_eval.sh",
"chars": 171,
"preview": "torchrun --nproc_per_node 1 \\\n -m inf_cl_train.main \\\n --imagenet-val datasets/imagenet-1k/val \\\n --model ViT-B"
},
{
"path": "scripts/laion400m/clip_vit-b-32_bs256k.sh",
"chars": 1821,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/laion400m/lit_vit-b-16_bs256k.sh",
"chars": 1941,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/laion400m/lit_vit-b-32_bs256k.sh",
"chars": 1941,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "scripts/laion400m/lit_vit-l-16_bs256k.sh",
"chars": 1941,
"preview": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16"
},
{
"path": "tests/example.py",
"chars": 1496,
"preview": "import torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport numpy as np\n\nfrom inf_cl import cal"
}
]
About this extraction
This page contains the full source code of the DAMO-NLP-SG/Inf-CLIP GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 152 files (471.8 KB), approximately 130.5k tokens, and a symbol index with 447 extracted functions, classes, methods, constants, and types. You can use this output with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input, and you can copy it to your clipboard or download it as a .txt file.
Extracted by GitExtract — a free GitHub-repository-to-text converter for AI. Built by Nikandr Surkov.