Repository: DAMO-NLP-SG/Inf-CLIP
Branch: main
Commit: d9f2833b3753
Files: 152
Total size: 471.8 KB
Directory structure:
gitextract_u47kktxp/
├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── inf_cl/
│ ├── __init__.py
│ ├── flash.py
│ └── ring.py
├── inf_clip/
│ ├── __init__.py
│ ├── constants.py
│ ├── factory.py
│ ├── model_configs/
│ │ ├── EVA01-g-14-plus.json
│ │ ├── EVA01-g-14.json
│ │ ├── EVA02-B-16.json
│ │ ├── EVA02-E-14-plus.json
│ │ ├── EVA02-E-14.json
│ │ ├── EVA02-L-14-336.json
│ │ ├── EVA02-L-14.json
│ │ ├── LiT-B-16.json
│ │ ├── LiT-B-32.json
│ │ ├── LiT-L-16.json
│ │ ├── MobileCLIP-B.json
│ │ ├── MobileCLIP-S1.json
│ │ ├── MobileCLIP-S2.json
│ │ ├── RN101-quickgelu.json
│ │ ├── RN101.json
│ │ ├── RN50-quickgelu.json
│ │ ├── RN50.json
│ │ ├── RN50x16.json
│ │ ├── RN50x4.json
│ │ ├── RN50x64.json
│ │ ├── ViT-B-16-SigLIP-256.json
│ │ ├── ViT-B-16-SigLIP-384.json
│ │ ├── ViT-B-16-SigLIP-512.json
│ │ ├── ViT-B-16-SigLIP-i18n-256.json
│ │ ├── ViT-B-16-SigLIP.json
│ │ ├── ViT-B-16-plus-240.json
│ │ ├── ViT-B-16-plus.json
│ │ ├── ViT-B-16-quickgelu.json
│ │ ├── ViT-B-16.json
│ │ ├── ViT-B-32-256.json
│ │ ├── ViT-B-32-plus-256.json
│ │ ├── ViT-B-32-quickgelu.json
│ │ ├── ViT-B-32.json
│ │ ├── ViT-H-14-378-quickgelu.json
│ │ ├── ViT-H-14-CLIPA-336.json
│ │ ├── ViT-H-14-CLIPA.json
│ │ ├── ViT-H-14-quickgelu.json
│ │ ├── ViT-H-14.json
│ │ ├── ViT-H-16.json
│ │ ├── ViT-L-14-280.json
│ │ ├── ViT-L-14-336.json
│ │ ├── ViT-L-14-CLIPA-336.json
│ │ ├── ViT-L-14-CLIPA.json
│ │ ├── ViT-L-14-quickgelu.json
│ │ ├── ViT-L-14.json
│ │ ├── ViT-L-16-320.json
│ │ ├── ViT-L-16-SigLIP-256.json
│ │ ├── ViT-L-16-SigLIP-384.json
│ │ ├── ViT-L-16.json
│ │ ├── ViT-M-16-alt.json
│ │ ├── ViT-M-16.json
│ │ ├── ViT-M-32-alt.json
│ │ ├── ViT-M-32.json
│ │ ├── ViT-S-16-alt.json
│ │ ├── ViT-S-16.json
│ │ ├── ViT-S-32-alt.json
│ │ ├── ViT-S-32.json
│ │ ├── ViT-SO400M-14-SigLIP-384.json
│ │ ├── ViT-SO400M-14-SigLIP.json
│ │ ├── ViT-bigG-14-CLIPA-336.json
│ │ ├── ViT-bigG-14-CLIPA.json
│ │ ├── ViT-bigG-14.json
│ │ ├── ViT-e-14.json
│ │ ├── ViT-g-14.json
│ │ ├── ViTamin-B-LTT.json
│ │ ├── ViTamin-B.json
│ │ ├── ViTamin-L-256.json
│ │ ├── ViTamin-L-336.json
│ │ ├── ViTamin-L.json
│ │ ├── ViTamin-L2-256.json
│ │ ├── ViTamin-L2-336.json
│ │ ├── ViTamin-L2.json
│ │ ├── ViTamin-S-LTT.json
│ │ ├── ViTamin-S.json
│ │ ├── ViTamin-XL-256.json
│ │ ├── ViTamin-XL-336.json
│ │ ├── ViTamin-XL-384.json
│ │ ├── coca_ViT-B-32.json
│ │ ├── coca_ViT-L-14.json
│ │ ├── coca_base.json
│ │ ├── coca_roberta-ViT-B-32.json
│ │ ├── convnext_base.json
│ │ ├── convnext_base_w.json
│ │ ├── convnext_base_w_320.json
│ │ ├── convnext_large.json
│ │ ├── convnext_large_d.json
│ │ ├── convnext_large_d_320.json
│ │ ├── convnext_small.json
│ │ ├── convnext_tiny.json
│ │ ├── convnext_xlarge.json
│ │ ├── convnext_xxlarge.json
│ │ ├── convnext_xxlarge_320.json
│ │ ├── mt5-base-ViT-B-32.json
│ │ ├── mt5-xl-ViT-H-14.json
│ │ ├── nllb-clip-base-siglip.json
│ │ ├── nllb-clip-base.json
│ │ ├── nllb-clip-large-siglip.json
│ │ ├── nllb-clip-large.json
│ │ ├── roberta-ViT-B-32.json
│ │ ├── swin_base_patch4_window7_224.json
│ │ ├── vit_medium_patch16_gap_256.json
│ │ ├── vit_relpos_medium_patch16_cls_224.json
│ │ ├── xlm-roberta-base-ViT-B-32.json
│ │ └── xlm-roberta-large-ViT-H-14.json
│ ├── models/
│ │ ├── clip_arch.py
│ │ ├── coca_arch.py
│ │ ├── hf_configs.py
│ │ ├── hf_model.py
│ │ ├── lit_arch.py
│ │ ├── loss.py
│ │ ├── modified_resnet.py
│ │ ├── pos_embed.py
│ │ ├── timm_model.py
│ │ ├── tokenizer.py
│ │ ├── transform.py
│ │ └── transformer.py
│ ├── openai.py
│ ├── pretrained.py
│ ├── train/
│ │ ├── data.py
│ │ ├── engine.py
│ │ ├── main.py
│ │ ├── optims.py
│ │ ├── params.py
│ │ └── utils.py
│ ├── utils.py
│ ├── zero_shot_classifier.py
│ └── zero_shot_metadata.py
├── pyproject.toml
├── requirements.txt
├── scripts/
│ ├── benchmarks_eval.sh
│ ├── cc12m/
│ │ ├── clip_vit-b-32_bs32k.sh
│ │ ├── lit_vit-b-16_bs32k.sh
│ │ └── lit_vit-b-32_bs32k.sh
│ ├── cc3m/
│ │ ├── clip_r50_bs4k.sh
│ │ ├── clip_vit-b-32_bs16k.sh
│ │ └── lit_vit-b-32_bs16k.sh
│ ├── imagenet_eval.sh
│ └── laion400m/
│ ├── clip_vit-b-32_bs256k.sh
│ ├── lit_vit-b-16_bs256k.sh
│ ├── lit_vit-b-32_bs256k.sh
│ └── lit_vit-l-16_bs256k.sh
└── tests/
└── example.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitattributes
================================================
*.py linguist-language=python
*.ipynb linguist-documentation
================================================
FILE: .gitignore
================================================
**/logs/
**/wandb/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
sync.sh
gpu1sync.sh
.idea
*.pdf
**/._*
**/*DS_*
**.jsonl
src/sbatch
src/misc
.vscode
src/debug
core.*
# Allow
!src/evaluation/misc/results_dbs/*
# log dirs
/work_dirs*/
/datasets/
# oss logs
/.ossutil*
/ossutil*
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏
[Paper (arXiv:2410.17243)](https://arxiv.org/abs/2410.17243) · [Hugging Face Paper Page](https://huggingface.co/papers/2410.17243) · [PyPI: inf-cl](https://pypi.org/project/inf-cl) · [License](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/LICENSE)
[Open Issues](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aopen+is%3Aissue) · [Closed Issues](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aissue+is%3Aclosed) · [Zhihu Blog](https://zhuanlan.zhihu.com/p/1681887214) · [X (Twitter)](https://x.com/lixin4ever/status/1849669129613226457)

💡 Some other multimodal foundation model projects from our team may interest you ✨.
> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922)
> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing
[GitHub](https://github.com/DAMO-NLP-SG/VCD) · [arXiv](https://arxiv.org/abs/2311.16922)
> [**VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs**](https://github.com/DAMO-NLP-SG/VideoLLaMA2)
> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing
[GitHub](https://github.com/DAMO-NLP-SG/VideoLLaMA2) · [arXiv](https://arxiv.org/abs/2406.07476)
> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787)
> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing
[GitHub](https://github.com/DAMO-NLP-SG/CMM) · [arXiv](https://arxiv.org/abs/2410.12787)
## 📰 News
* **[2024.10.18]** Release training and evaluation codes of Inf-CLIP.

## 🛠️ Requirements and Installation
Basic Dependencies:
* Python >= 3.8
* PyTorch >= 2.0.0
* CUDA Version >= 11.8

[Remote] Install Inf-CL:
```bash
# remote installation
pip install inf_cl -i https://pypi.org/simple
```

[Local] Install Inf-CL:
```bash
git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -e .
```

Install required packages:
```bash
pip install -r requirements.txt
```
## ⭐ Features
`inf_cl` is the Triton implementation of the Inf-CL loss:
- [x] [Ring-CL (inf_cl/ring.py#L238)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L238)
- [x] [Inf-CL (inf_cl/ring.py#L251)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L251)

`inf_clip` is the CLIP training codebase with the Inf-CL loss and other training features:
- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L180)
- [x] [Gradient Cache (inf_clip/train/train.py#L292)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L292)
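
The exported loss functions share the signature `(q, k, labels=None, scale=None, head_dim=256)` and also work without `torch.distributed`: `cal_ring_loss`/`cal_inf_loss` fall back to the single-GPU `cal_flash_loss` when no process group is initialized. A minimal single-GPU sketch (tensor sizes here are illustrative only):

```python
import torch
import torch.nn.functional as F

from inf_cl import cal_flash_loss

# 1024 image/text feature pairs; the feature dim must be a multiple of head_dim (default 256).
q = F.normalize(torch.rand(1024, 768, device="cuda"), dim=-1).requires_grad_()
k = F.normalize(torch.rand(1024, 768, device="cuda"), dim=-1).requires_grad_()

loss = cal_flash_loss(q, k).mean()  # labels default to the diagonal
loss.backward()
```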
## 🔑 Usage
A simple example of how to adopt our Inf-CL loss for contrastive learning. Run it with the following command:
```
torchrun --nproc_per_node 2 tests/example.py
```
```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3         # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256         # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key
    l = l.requires_grad_()                            # Logit scale

    return q, k, l


if __name__ == "__main__":
    # Initialize the distributed environment (launched via torchrun)
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Example: image-text contrastive learning, where q holds the local image
    # features, k the local text features, and l the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # Labels are the diagonal elements by default.
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)
```
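
For standard CLIP training, the loss is applied symmetrically in both directions (image-to-text and text-to-image), as in the benchmark at the bottom of `inf_cl/ring.py`:

```python
loss_i2t = cal_inf_loss(q, k, scale=l.exp())
loss_t2i = cal_inf_loss(k, q, scale=l.exp())
loss = (loss_i2t + loss_t2i) * 0.5
```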
## 🚀 Main Results
### Memory Cost

\* denotes adopting "data offload" strategy.
### Max Supported Batch Size

### Speed

### Batch Size Scaling

Training at larger data scales requires a correspondingly larger batch size.
## 🗝️ Training & Evaluation
### Quick Start
To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate it on the mainstream CLIP benchmarks.
1. Training Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   └── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│       ├── 00000.tar
│       ├── 00001.tar
│       ├── ...
│       └── 41407.tar
```
2. Command:
```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```
3. Evaluation Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of imagenet from https://huggingface.co/datasets/ILSVRC/imagenet-1k/tree/main/data
│   │   └── val/ # python datasets/reformat_imagenet.py
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   └── clip-benchmark/ # bash datasets/benchmarks_download.sh
│       ├── wds_mscoco_captions
│       ├── wds_flickr8k
│       ├── wds_flickr30k
│       ├── wds_imagenet1k
│       ├── wds_imagenetv2
│       ├── wds_imagenet_sketch
│       ├── wds_imagenet-a
│       ├── wds_imagenet-r
│       ├── wds_imagenet-o
│       └── wds_objectnet
```
4. Command:
```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```
## 📑 Citation
If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:
```bibtex
@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}
```
## 👍 Acknowledgement
The codebase of Inf-CLIP is adapted from [**OpenCLIP**](https://github.com/mlfoundations/open_clip). We are also grateful to the following projects that Inf-CL builds upon:
* [**OpenAI CLIP**](https://openai.com/index/clip/), [**img2dataset**](https://github.com/rom1504/img2dataset), [**CLIP-Benchmark**](https://github.com/LAION-AI/CLIP_benchmark).
* [**FlashAttention**](https://github.com/Dao-AILab/flash-attention), [**RingAttention**](https://github.com/haoliuhl/ringattention), [**RingFlashAttention**](https://github.com/zhuzilin/ring-flash-attention).
## 🔒 License
This project is released under the Apache 2.0 license as found in the LICENSE file.
The service is a research preview intended for **non-commercial use ONLY**, subject to the model license of CLIP, the terms of use of data generated by OpenAI, and the license of LAION. Please get in touch with us if you find any potential violations.
================================================
FILE: inf_cl/__init__.py
================================================
from .flash import cal_flash_loss
from .ring import cal_ring_loss, cal_inf_loss
================================================
FILE: inf_cl/flash.py
================================================
import math

import torch
import torch.nn.functional as F
import numpy as np

import triton
import triton.language as tl


@triton.jit
def _prob_fwd_kernel(
    Q,
    K,
    LSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # start index of sequence length
    start_m = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to Q and K
    q_ptrs = Q + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    # initialize the running max m and the logsumexp accumulator
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # loop over k blocks and update the accumulator
    end_n = seqlen_k
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seems to make the result wrong
        m_ij = tl.maximum(tl.max(qk, 1), m_i)
        p = tl.exp(qk - m_ij[:, None])
        # Fix out-of-bound access
        p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0)
        # -- update statistics
        lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1)
        m_i = m_ij
    lse_i = m_i + tl.log(lse_i)
    # mask out the padded values
    lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0)
    tl.store(LSE + offs_m, lse_i)
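

# NOTE: the kernel above maintains a numerically stable streaming logsumexp:
# each new block updates the running max m_i and rescales the running sum by
# exp(m_i - m_ij), the online-softmax update popularized by FlashAttention.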
@triton.jit
def _dq_prob_bwd_kernel(
    Q,
    K,
    dQ,
    LSE,
    dLSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length
    start_m = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to Q, dQ and K
    q_ptrs = Q + ndims * offs_m[:, None]
    dq_ptrs = dQ + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    # load lse and its gradient
    lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0)
    dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0)
    # loop over k blocks and accumulate dq
    end_n = seqlen_k
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute q grad ----
            # NOTE: tl.float32 adopts tf32, which causes precision inconsistency with torch.
            # One solution for this problem:
            # Refer to issue: https://github.com/triton-lang/triton/issues/4574
            # if allow_tf32:
            k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1)
            q_grad = tl.dot(qk_grad, k)
            # Another solution for this problem:
            # Refer to https://github.com/triton-lang/triton/issues/376
            # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)
            # -- store dq ----
            dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)
            tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)
@triton.jit
def _dk_prob_bwd_kernel(
    Q,
    K,
    dK,
    LSE,
    dLSE,
    nheads,
    seqlen_q,
    seqlen_k,
    BLOCK_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;"
    # start index of sequence length
    start_n = tl.program_id(0)
    # initialize offsets
    ndims = nheads * BLOCK_HEADDIM
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to Q, K and dK
    q_ptrs = Q + ndims * offs_m[:, None]
    k_ptrs = K + ndims * offs_n[:, None]
    dk_ptrs = dK + ndims * offs_n[:, None]
    # loop over q blocks and accumulate dk
    end_m = seqlen_q
    for start_m in range(0, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        # load lse and its gradient
        lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)
        dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute qk ----
            qk += tl.dot(q, tl.trans(k))
        qk_grad = tl.exp(qk - lse[:, None])
        qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)
        qk_grad = qk_grad * dlse[:, None]
        qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)
        for off_h in range(nheads):
            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]
            # -- fetch q and k of a single head ----
            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)
            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            # -- compute k grad ----
            q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1)
            k_grad = tl.dot(tl.trans(qk_grad), q)
            # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32))
            # -- store dk ----
            dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)
            tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k)
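

# NOTE: dQ and dK are computed by two separate kernels so that each program
# instance exclusively owns the rows it writes (query rows in
# _dq_prob_bwd_kernel, key rows in _dk_prob_bwd_kernel), which avoids atomic
# additions across thread blocks.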
def _flash_prob_forward(q, k):
    # shape constraints
    seqlen_q, nheads, d = q.shape
    seqlen_k, _, _ = k.shape
    assert k.shape == (seqlen_k, nheads, d)
    # assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype, "All tensors must have the same type"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda

    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((seqlen_q_rounded), device=q.device, dtype=torch.float32)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = 64
    BLOCK_N = 64
    num_warps = 8
    num_stages = 1

    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _prob_fwd_kernel[grid](
        q,
        k,
        lse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    lse = lse[:seqlen_q]
    return lse
def _flash_prob_backward(q, k, lse, dlse):
    # shape constraints
    seqlen_q, nheads, d = q.shape
    seqlen_k, _, _ = k.shape
    assert k.shape == (seqlen_k, nheads, d)
    # assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype, "All tensors must have the same type"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda

    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.zeros_like(k, dtype=torch.float32)

    q = q.contiguous()
    k = k.contiguous()
    dlse = dlse.contiguous()

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    BLOCK_M = 64
    BLOCK_N = 64
    num_warps = 8
    num_stages = 1

    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1)
    _dq_prob_bwd_kernel[grid](
        q,
        k,
        dq,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # swap the tile sizes between the dq and dk kernels
    BLOCK_M, BLOCK_N = BLOCK_N, BLOCK_M
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), 1)
    _dk_prob_bwd_kernel[grid](
        q,
        k,
        dk,
        lse,
        dlse,
        nheads,
        seqlen_q,
        seqlen_k,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    dq = dq[:seqlen_q]
    dk = dk[:seqlen_k]

    return dq, dk
class FlashProb(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k):
        lse = _flash_prob_forward(q, k)
        ctx.save_for_backward(q, k, lse)
        return lse

    @staticmethod
    def backward(ctx, dlse):
        q, k, lse = ctx.saved_tensors
        dq, dk = _flash_prob_backward(q, k, lse, dlse)
        return dq, dk
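

# NOTE: the loss below exploits the decomposition of softmax cross entropy,
#   loss_i = -qk[i, labels[i]] + logsumexp_j qk[i, j],
# so only the positive-pair dot products (the "numerator") and the LSE from the
# fused kernel are needed; the full [bq, bk] logits matrix is never
# materialized in global memory.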
def _cal_flash_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    # NOTE: logits forward or backward should keep fp32 for better precision
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = FlashProb.apply(q, k)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss


def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):
    if labels is None:
        labels = torch.arange(q.shape[0], device=q.device)
    if scale is None:
        scale = 1.0
    return _cal_flash_loss(scale * q, k, labels, head_dim)
if __name__ == '__main__':
    import time

    # Parameters
    num_heads = 3         # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256         # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q, num_heads * d_model), dtype=torch.float32, device="cuda")  # Query
    k = torch.rand((seq_length_k, num_heads * d_model), dtype=torch.float32, device="cuda")  # Key
    l = torch.ones([], device="cuda") * np.log(1 / 0.02); l.requires_grad = True
    q = F.normalize(q, p=2, dim=-1); q.requires_grad = True
    k = F.normalize(k, p=2, dim=-1); k.requires_grad = True

    q1 = q.clone().detach().requires_grad_(True)
    k1 = k.clone().detach().requires_grad_(True)
    l1 = l.clone().detach().requires_grad_(True)

    labels = torch.arange(seq_length_q).cuda()

    for i in range(1000):
        # A. torch gradient
        start = time.time()
        qk = torch.einsum("md,nd->mn", l.exp() * q, k)
        loss = F.cross_entropy(qk, labels, reduction="mean")
        loss.backward()
        end = time.time()

        # B. triton gradient
        start1 = time.time()
        loss1 = cal_flash_loss(q1, k1, labels, l1.exp())
        loss1 = loss1.mean()
        loss1.backward()
        end1 = time.time()

        print("========= Difference =========")
        print(end - start, end1 - start1, l.grad, l1.grad)
        print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None
================================================
FILE: inf_cl/ring.py
================================================
import os
import math
import random

import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
import numpy as np

import triton
import triton.language as tl

from .flash import _flash_prob_forward, _flash_prob_backward, _cal_flash_loss


class RingComm:
    def __init__(self, process_group: dist.ProcessGroup):
        self._process_group = process_group
        self._ops = []
        self.rank = dist.get_rank(self._process_group)
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size
        # print(f'rank: {self.rank}, send_rank: {self.send_rank}, recv_rank: {self.recv_rank}')

        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(self, to_send, recv_tensor=None):
        if recv_tensor is None:
            res = torch.empty_like(to_send)
        else:
            res = recv_tensor

        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res

    def commit(self):
        if self._reqs is not None:
            raise RuntimeError("commit called twice")
        self._reqs = dist.batch_isend_irecv(self._ops)

    def wait(self):
        if self._reqs is None:
            raise RuntimeError("wait called before commit")
        for req in self._reqs:
            req.wait()
        self._reqs = None
        self._ops = []
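

# NOTE: RingComm implements one step of a ring exchange: each rank sends its
# current block to rank + 1 while receiving rank - 1's block, so after
# world_size - 1 steps every rank has seen every k block while holding at most
# one in-flight block at a time.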
class GradientGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, dx):
        dist.all_reduce(dx)
        return dx
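

# NOTE: GradientGather is the identity in the forward pass but all-reduces the
# gradient in the backward pass; it is applied to the shared logit scale so
# that its gradient accounts for every rank's local shard.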
class RingProb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, group):
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()
            # vanilla lse
            qk = torch.einsum("mhd,nhd->mn", q, k)
            block_lse = torch.log(torch.exp(qk).sum(dim=-1))
            if step == 0:
                lse = block_lse
            else:
                lse = lse - F.logsigmoid(lse - block_lse)
            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group
        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)
        d_k_comm = RingComm(ctx.group)
        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # vanilla gradient calculation
            qk = torch.einsum("mhd,nhd->mn", q, k)
            qk_grad = torch.exp(qk - lse[:, None]).float()
            qk_grad = qk_grad * dlse[:, None]
            block_dq_buffer = torch.einsum("mn,nhd->mhd", qk_grad, k.float())
            block_dk_buffer = torch.einsum("nm,mhd->nhd", qk_grad.T, q.float())

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()

        return dq, next_dk, None
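

# NOTE: the running-LSE merge in the forward passes relies on the identity
#   logaddexp(a, b) = a + softplus(b - a) = a - logsigmoid(a - b),
# which is why each step folds in the per-block LSE via
#   lse = lse - F.logsigmoid(lse - block_lse).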
class InfProb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, group):
        rank = dist.get_rank()
        k = k.contiguous()
        comm = RingComm(group)

        colle = [q, k]

        lse = None
        next_k = None
        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k: torch.Tensor = comm.send_recv(k)
                comm.commit()
            # flash lse
            block_lse = _flash_prob_forward(q, k)
            if step == 0:
                lse = block_lse
            else:
                lse = lse - F.logsigmoid(lse - block_lse)
            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k

        # this should be out_padded
        colle.append(lse)
        ctx.save_for_backward(*colle)
        ctx.group = group
        return lse

    @staticmethod
    def backward(ctx, dlse):
        rank = dist.get_rank()
        q, k, lse = ctx.saved_tensors
        k_comm = RingComm(ctx.group)
        d_k_comm = RingComm(ctx.group)
        dq, dk = None, None
        next_dk = None

        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)

        next_dk, next_k = None, None

        for step in range(k_comm.world_size):
            if step + 1 != k_comm.world_size:
                next_k = k_comm.send_recv(k)
                k_comm.commit()

            # flash gradient calculation
            block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse)

            if step == 0:
                dq = block_dq_buffer
                dk = block_dk_buffer
            else:
                dq += block_dq_buffer
                d_k_comm.wait()
                dk = block_dk_buffer + next_dk

            if step + 1 != k_comm.world_size:
                k_comm.wait()
                k = next_k

            next_dk = d_k_comm.send_recv(dk)
            d_k_comm.commit()

        d_k_comm.wait()

        return dq, next_dk, None
def set_seed(rank, seed=42):
    seed = rank + seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _cal_ring_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = RingProb.apply(q, k, None)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss


def _cal_inf_loss(q, k, labels, head_dim=256):
    bq = q.shape[0]
    bk = k.shape[0]
    q = q.view(bq, -1, head_dim).float()
    k = k.view(bk, -1, head_dim).float()

    lse = InfProb.apply(q, k, None)
    numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...])
    loss = -numerator + lse

    return loss
def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):
    """The Triton implementation of Ring-CL.

    Args:
        q (torch.Tensor): The column tensor in the contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in the contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When set to None, the labels default to the range [0, B). Defaults to None.
        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension (must be 16, 32, 64, 128, or 256). Defaults to 256.
    """
    if labels is None:
        labels = torch.arange(q.shape[0]).to(q.device)
    if scale is None:
        scale = 1.0
    else:
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_ring_loss(scale * q, k, labels, head_dim).mean()
    else:
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()


def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):
    """The Triton implementation of Inf-CL.

    Args:
        q (torch.Tensor): The column tensor in the contrastive loss. The shape is [B, D].
        k (torch.Tensor): The row tensor in the contrastive loss. The shape is [B, D].
        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When set to None, the labels default to the range [0, B). Defaults to None.
        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.
        head_dim (int, optional): The head dimension (must be 16, 32, 64, 128, or 256). Defaults to 256.
    """
    if labels is None:
        labels = torch.arange(q.shape[0]).to(q.device)
    if scale is None:
        scale = 1.0
    else:
        scale = GradientGather.apply(scale)
    if torch.distributed.is_initialized():
        return _cal_inf_loss(scale * q, k, labels, head_dim).mean()
    else:
        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()
if __name__ == "__main__":
    import time

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(f'cuda:{os.environ["LOCAL_RANK"]}')

    # Parameters
    dtype = torch.float32
    num_heads = 3         # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256         # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device="cuda")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device="cuda")
    l = torch.ones([], dtype=dtype, device="cuda") * np.log(1 / 0.07); l = l.requires_grad_()  # Logit scale
    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key

    q1 = q.clone().detach().requires_grad_()
    k1 = k.clone().detach().requires_grad_()
    l1 = l.clone().detach().requires_grad_()

    for i in range(1000):
        # A. local torch gradient
        start = time.time()
        # A.1. gather q, k
        gathered_q = [torch.zeros_like(q) for _ in range(world_size)]
        gathered_k = [torch.zeros_like(k) for _ in range(world_size)]
        dist.all_gather(gathered_q, q)
        dist.all_gather(gathered_k, k)
        gathered_q[rank] = q
        gathered_k[rank] = k
        all_q = torch.cat(gathered_q, dim=0)
        all_k = torch.cat(gathered_k, dim=0)
        # A.2. calculating qk logits
        qk = torch.einsum("md,nd->mn", l.exp() * all_q, all_k)
        kq = qk.T
        _labels = torch.arange(seq_length_q).to(q.device)
        # A.3. calculating loss
        loss_i2t = F.cross_entropy(qk, _labels, reduction="mean")
        loss_t2i = F.cross_entropy(kq, _labels, reduction="mean")
        # A.4. scaling loss to normal value
        scale_factor = (all_q.shape[0] / q.shape[0])
        loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor
        loss.backward()
        show_loss = loss.detach().clone()
        dist.all_reduce(show_loss)
        show_loss = show_loss / (world_size * scale_factor)
        end = time.time()

        dist.barrier()

        # B. triton implementation
        start1 = time.time()
        # labels = torch.arange(seq_length_q // world_size).to(q.device)
        loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp())
        loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp())
        loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5
        loss1.backward()
        end1 = time.time()

        dist.barrier()

        if rank == 0:
            print(rank, end - start, end1 - start1, loss, show_loss, loss1)
            print(l.grad, l1.grad, torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))

        q.grad = None; k.grad = None; l.grad = None
        q1.grad = None; k1.grad = None; l1.grad = None
================================================
FILE: inf_clip/__init__.py
================================================
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .models.tokenizer import SimpleTokenizer, tokenize, decode
from .models.transform import image_transform, AugmentationCfg
from .models.coca_arch import CoCa
from .models.clip_arch import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \
    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss
from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy
from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES
================================================
FILE: inf_clip/constants.py
================================================
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
INCEPTION_MEAN = (0.5, 0.5, 0.5)
INCEPTION_STD = (0.5, 0.5, 0.5)
================================================
FILE: inf_clip/factory.py
================================================
import json
import logging
import os
import re
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
    list_pretrained_tags_by_model, download_pretrained_from_hf, convert_state_dict
from .models.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH
from .models.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .models.clip_arch import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .models.coca_arch import CoCa
from .models.lit_arch import LiT
from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, FlashClipLoss, RingClipLoss, InfClipLoss, DiscoClipLoss
HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None


def _get_hf_config(model_id, cache_dir=None):
    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return config
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    if model_name.startswith(HF_HUB_PREFIX):
        model_name = model_name[len(HF_HUB_PREFIX):]
        try:
            config = _get_hf_config(model_name)['model_cfg']
        except Exception:
            tokenizer = HFTokenizer(
                model_name,
                context_length=context_length or DEFAULT_CONTEXT_LENGTH,
                **kwargs,
            )
            return tokenizer
    else:
        config = get_model_config(model_name)
        assert config is not None, f"No valid model config found for {model_name}."

    text_config = config.get('text_cfg', {})
    if 'tokenizer_kwargs' in text_config:
        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)
    else:
        tokenizer_kwargs = kwargs

    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    if 'hf_tokenizer_name' in text_config:
        tokenizer = HFTokenizer(
            text_config['hf_tokenizer_name'],
            context_length=context_length,
            **tokenizer_kwargs,
        )
    else:
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
):
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.pretrained import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
def create_model(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
force_preprocess_cfg: Optional[Dict[str, Any]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = False,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
require_pretrained: bool = False,
**model_kwargs,
):
force_preprocess_cfg = force_preprocess_cfg or {}
preprocess_cfg = asdict(PreprocessCfg())
has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
if has_hf_hub_prefix:
model_id = model_name[len(HF_HUB_PREFIX):]
checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
config = _get_hf_config(model_id, cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
model_cfg = config['model_cfg']
pretrained_hf = False # override, no need to load original HF text weights
else:
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
checkpoint_path = None
model_cfg = None
if isinstance(device, str):
device = torch.device(device)
if pretrained and pretrained.lower() == 'openai':
logging.info(f'Loading pretrained {model_name} from OpenAI.')
model = load_openai_model(
model_name,
precision=precision,
device=device,
cache_dir=cache_dir,
)
else:
model_cfg = model_cfg or get_model_config(model_name)
if model_cfg is not None:
logging.info(f'Loaded {model_name} model config.')
else:
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
raise RuntimeError(f'Model config for {model_name} not found.')
if force_quick_gelu:
# override for use of QuickGELU on non-OpenAI transformer models
model_cfg["quick_gelu"] = True
if force_patch_dropout is not None and force_patch_dropout != False:
# override the default patch dropout value
model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
if force_image_size is not None and force_image_size != False:
# override model config's image size
model_cfg["vision_cfg"]["image_size"] = force_image_size
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
if pretrained_image:
if is_timm_model:
# pretrained weight loading for timm models set via vision_cfg
model_cfg['vision_cfg']['timm_model_pretrained'] = True
else:
assert False, 'pretrained image towers currently only supported for timm models'
# cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
cast_dtype = get_cast_dtype(precision)
is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
if is_hf_model:
# load pretrained weights for HF text model IFF no CLIP weights being loaded
# NOTE: disable pretrained_hf arguments.
# model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
model_cfg['text_cfg']['hf_model_pretrained'] = model_cfg['text_cfg']['hf_model_pretrained']
# and not pretrained
custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)
model_arch = model_cfg.pop("arch", "CLIP")
if custom_text:
if "CLIP" in model_arch:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
elif "LiT" in model_arch:
model = LiT(**model_cfg, cast_dtype=cast_dtype)
elif "CoCa" in model_arch or "multimodal_cfg" in model_cfg:
model = CoCa(**model_cfg, cast_dtype=cast_dtype)
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if precision in ("fp16", "bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
# manual mixed precision that matches original OpenAI behaviour
if is_timm_model:
# FIXME this is a bit janky, create timm based model in low-precision and
# then cast only LayerNormFp32 instances back to float32 so they don't break.
# Why? The convert_weights_to_lp fn only works with native models.
model.to(device=device, dtype=dtype)
from .transformer import LayerNormFp32
def _convert_ln(m):
if isinstance(m, LayerNormFp32):
m.weight.data = m.weight.data.to(torch.float32)
m.bias.data = m.bias.data.to(torch.float32)
model.apply(_convert_ln)
else:
model.to(device=device)
convert_weights_to_lp(model, dtype=dtype)
elif precision in ("pure_fp16", "pure_bf16"):
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
model.to(device=device, dtype=dtype)
else:
model.to(device=device)
pretrained_loaded = False
if pretrained:
checkpoint_path = ''
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
elif os.path.exists(pretrained):
checkpoint_path = pretrained
if checkpoint_path:
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
load_checkpoint(model, checkpoint_path)
else:
error_str = (
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
logging.warning(error_str)
raise RuntimeError(error_str)
pretrained_loaded = True
elif has_hf_hub_prefix:
logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
load_checkpoint(model, checkpoint_path)
pretrained_loaded = True
if require_pretrained and not pretrained_loaded:
# callers of create_model_from_pretrained always expect pretrained weights
raise RuntimeError(
f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
if output_dict and hasattr(model, "output_dict"):
model.output_dict = True
if jit:
model = torch.jit.script(model)
# set image preprocessing configuration in model attributes for convenience
if getattr(model.visual, 'image_size', None) is not None:
# use image_size set on model creation (via config or force_image_size arg)
force_preprocess_cfg['size'] = model.visual.image_size
set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))
return model
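# Precision summary (a sketch of the branches above, not additional behaviour):
#   'fp16' / 'bf16'           -> manual mixed precision; timm towers are cast whole
#                                with LayerNormFp32 restored to float32, native
#                                towers go through convert_weights_to_lp
#   'pure_fp16' / 'pure_bf16' -> the entire model is cast via model.to(dtype=...)
#   'fp32', 'amp', ...        -> weights stay float32; any autocast for 'amp' is
#                                applied in the training/eval loop, not here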
def create_loss(args):
if args.distill_model:
return DistillClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif "coca" in args.model.lower():
return CoCaLoss(
caption_loss_weight=args.coca_caption_loss_weight,
clip_loss_weight=args.coca_contrastive_loss_weight,
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif args.siglip:
assert not args.horovod, "Horovod not currently supported for SigLip"
return SigLipLoss(
rank=args.rank,
world_size=args.world_size,
)
elif args.flashloss:
return FlashClipLoss(
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
elif args.ringloss:
return RingClipLoss(
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod
)
elif args.infloss:
return InfClipLoss(
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod
)
elif args.discoloss:
return DiscoClipLoss(
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod
)
return ClipLoss(
local_loss=args.local_loss,
gather_with_grad=args.gather_with_grad,
cache_labels=True,
rank=args.rank,
world_size=args.world_size,
use_horovod=args.horovod,
)
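# Example (a minimal, hypothetical sketch): create_loss reads mutually exclusive
# flags off the training args namespace; only the attributes it actually touches
# are set below.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       distill_model=None, model='ViT-B-32', siglip=False,
#       flashloss=False, ringloss=False, infloss=True, discoloss=False,
#       local_loss=False, gather_with_grad=False,
#       rank=0, world_size=1, horovod=False,
#   )
#   loss_fn = create_loss(args)  # -> InfClipLoss(rank=0, world_size=1, use_horovod=False)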
def create_model_and_transforms(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_patch_dropout: Optional[float] = None,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
pretrained_image: bool = False,
pretrained_hf: bool = False,
cache_dir: Optional[str] = None,
output_dict: Optional[bool] = None,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_patch_dropout=force_patch_dropout,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
pretrained_image=pretrained_image,
pretrained_hf=pretrained_hf,
cache_dir=cache_dir,
output_dict=output_dict,
**model_kwargs,
)
pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
preprocess_train = image_transform_v2(
pp_cfg,
is_train=True,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform_v2(
pp_cfg,
is_train=False,
)
return model, preprocess_train, preprocess_val
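# Example usage (a sketch; the pretrained tag is illustrative and must exist in
# the pretrained registry for the chosen model):
#
#   model, preprocess_train, preprocess_val = create_model_and_transforms(
#       'ViT-B-32', pretrained='laion2b_s34b_b79k', device='cuda')
#   model.eval()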
def create_model_from_pretrained(
model_name: str,
pretrained: Optional[str] = None,
precision: str = 'fp32',
device: Union[str, torch.device] = 'cpu',
jit: bool = False,
force_quick_gelu: bool = False,
force_custom_text: bool = False,
force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_mean: Optional[Tuple[float, ...]] = None,
image_std: Optional[Tuple[float, ...]] = None,
image_interpolation: Optional[str] = None,
image_resize_mode: Optional[str] = None, # only effective for inference
return_transform: bool = True,
cache_dir: Optional[str] = None,
**model_kwargs,
):
force_preprocess_cfg = merge_preprocess_kwargs(
{}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
model = create_model(
model_name,
pretrained,
precision=precision,
device=device,
jit=jit,
force_quick_gelu=force_quick_gelu,
force_custom_text=force_custom_text,
force_image_size=force_image_size,
force_preprocess_cfg=force_preprocess_cfg,
cache_dir=cache_dir,
require_pretrained=True,
**model_kwargs,
)
if not return_transform:
return model
preprocess = image_transform_v2(
PreprocessCfg(**model.visual.preprocess_cfg),
is_train=False,
)
return model, preprocess
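# Example usage (a sketch): unlike create_model_and_transforms, this entry point
# enforces require_pretrained=True and returns a single inference transform. The
# hub id below is illustrative.
#
#   model, preprocess = create_model_from_pretrained(
#       'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K')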
================================================
FILE: inf_clip/model_configs/EVA01-g-14-plus.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA01-g-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva_giant_patch14_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-B-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_base_patch16_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-E-14-plus.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-E-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_enormous_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-L-14-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"timm_model_name": "eva02_large_patch14_clip_336",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/EVA02-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "eva02_large_patch14_clip_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/LiT-B-16.json
================================================
{
"arch": "LiT-B-16",
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch16_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-base-uncased",
"hf_model_name": "bert-base-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/LiT-B-32.json
================================================
{
"arch": "LiT-B-32",
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch32_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-base-uncased",
"hf_model_name": "bert-base-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/LiT-L-16.json
================================================
{
"arch": "LiT-L-16",
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_large_patch16_224",
"timm_model_pretrained": true,
"timm_pool": "token",
"timm_proj": "linear"
},
"text_cfg": {
"hf_tokenizer_name": "bert-large-uncased",
"hf_model_name": "bert-large-uncased",
"hf_model_pretrained": true,
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-B.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_base_mci_224",
"timm_model_pretrained": false,
"timm_pool": "token",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": false
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-S1.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci1",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/MobileCLIP-S2.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "fastvit_mci2",
"timm_model_pretrained": false,
"timm_pool": "avg",
"timm_proj": null,
"timm_drop": 0.0,
"timm_drop_path": 0.0,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"no_causal_mask": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/RN101-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN101.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
23,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": [
3,
4,
6,
3
],
"width": 64,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 384,
"layers": [
6,
8,
18,
8
],
"width": 96,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x4.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 288,
"layers": [
4,
6,
10,
6
],
"width": 80,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/RN50x64.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 448,
"layers": [
3,
15,
36,
10
],
"width": 128,
"patch_size": null
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-256.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_base_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-384.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_base_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-512.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 512,
"timm_model_name": "vit_base_patch16_siglip_512",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_base_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 250000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-SigLIP.json
================================================
{
"embed_dim": 768,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_base_patch16_siglip_224",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 768,
"heads": 12,
"layers": 12,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-plus-240.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 240,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-plus.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 896,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-256.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 256,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-plus-256.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"image_size": 256,
"layers": 12,
"width": 896,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32-quickgelu.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-378-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 378,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA-336.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 336,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1024,
"heads": 16,
"layers": 24,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-CLIPA.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1024,
"heads": 16,
"layers": 24,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14-quickgelu.json
================================================
{
"embed_dim": 1024,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-H-16.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-280.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 280,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 336,
"layers": 24,
"width": 1024,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 768,
"heads": 12,
"layers": 12,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-CLIPA.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 768,
"heads": 12,
"layers": 12,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14-quickgelu.json
================================================
{
"embed_dim": 768,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 320,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-256.json
================================================
{
"embed_dim": 1024,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 256,
"timm_model_name": "vit_large_patch16_siglip_256",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1024,
"heads": 16,
"layers": 24,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16-SigLIP-384.json
================================================
{
"embed_dim": 1024,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_large_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1024,
"heads": 16,
"layers": 24,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-L-16.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-16-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16,
"ls_init_value": 1e-4
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-16.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-32-alt.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-M-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 512,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-16-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-16.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 16
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-32-alt.json
================================================
{
"embed_dim": 256,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 256,
"heads": 4,
"layers": 10
}
}
================================================
FILE: inf_clip/model_configs/ViT-S-32.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 384,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json
================================================
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_so400m_patch14_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 64,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP.json
================================================
{
"embed_dim": 1152,
"init_logit_bias": -10,
"custom_text": true,
"vision_cfg": {
"image_size": 224,
"timm_model_name": "vit_so400m_patch14_siglip_224",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"context_length": 16,
"vocab_size": 32000,
"hf_tokenizer_name": "timm/ViT-B-16-SigLIP",
"tokenizer_kwargs": {
"clean": "canonicalize"
},
"width": 1152,
"heads": 16,
"layers": 27,
"mlp_ratio": 3.7362,
"no_causal_mask": true,
"proj_bias": true,
"pool_type": "last",
"norm_kwargs":{
"eps": 1e-6
}
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 336,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1280,
"heads": 20,
"layers": 32,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14,
"no_ln_pre": true,
"pool_type": "avg",
"final_ln_after_pool": true
},
"text_cfg": {
"context_length": 32,
"vocab_size": 32000,
"hf_tokenizer_name": "bert-base-uncased",
"tokenizer_kwargs": {
"strip_sep_token": true
},
"width": 1280,
"heads": 20,
"layers": 32,
"pool_type": "last",
"no_causal_mask": true
}
}
================================================
FILE: inf_clip/model_configs/ViT-bigG-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 48,
"width": 1664,
"head_width": 104,
"mlp_ratio": 4.9231,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 32
}
}
================================================
FILE: inf_clip/model_configs/ViT-e-14.json
================================================
{
"embed_dim": 1280,
"vision_cfg": {
"image_size": 224,
"layers": 56,
"width": 1792,
"head_width": 112,
"mlp_ratio": 8.5715,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1280,
"heads": 20,
"layers": 36
}
}
================================================
FILE: inf_clip/model_configs/ViT-g-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 40,
"width": 1408,
"head_width": 88,
"mlp_ratio": 4.3637,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/ViTamin-B-LTT.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_base_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-B.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vitamin_base_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L-256.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L-336.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_large_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2-256.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2-336.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-L2.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "vitamin_large2_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-S-LTT.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "vitamin_small_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-S.json
================================================
{
"embed_dim": 384,
"vision_cfg": {
"timm_model_name": "vitamin_small_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 384,
"heads": 6,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-256.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-336.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_336",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 336
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/ViTamin-XL-384.json
================================================
{
"embed_dim": 1152,
"vision_cfg": {
"timm_model_name": "vitamin_xlarge_384",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1152,
"heads": 16,
"layers": 27
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12,
"attn_pooler_heads": 8
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_ViT-L-14.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14,
"attentional_pool": true,
"attn_pooler_heads": 8,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"embed_cls": true,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12,
"attn_pooler_heads": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_base.json
================================================
{
"embed_dim": 512,
"multimodal_cfg": {
"width": 768,
"context_length": 76,
"vocab_size": 64000,
"mlp_ratio": 4,
"layers": 12,
"dim_head": 64,
"heads": 12,
"n_queries": 256,
"attn_pooler_heads": 8
},
"vision_cfg": {
"image_size": 288,
"layers": 12,
"width": 768,
"patch_size": 18,
"output_tokens": true
},
"text_cfg": {
"context_length": 76,
"vocab_size": 64000,
"layers": 12,
"heads": 12,
"width": 768,
"embed_cls": true,
"output_tokens": true
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/coca_roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32,
"output_tokens": true
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"hf_proj_type": "linear",
"width": 768,
"output_tokens": true
},
"multimodal_cfg": {
"context_length": 76,
"width": 768,
"heads": 8,
"layers": 12
},
"custom_text": true
}
================================================
FILE: inf_clip/model_configs/convnext_base.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_base_w.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_base_w_320.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "convnext_base",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_large.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_large_d.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: inf_clip/model_configs/convnext_large_d_320.json
================================================
{
"embed_dim": 768,
"vision_cfg": {
"timm_model_name": "convnext_large",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "mlp",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 16
}
}
================================================
FILE: inf_clip/model_configs/convnext_small.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "convnext_small",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_tiny.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_tiny",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/convnext_xlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 20
}
}
================================================
FILE: inf_clip/model_configs/convnext_xxlarge.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/convnext_xxlarge_320.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"timm_model_name": "convnext_xxlarge",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"timm_drop": 0.0,
"timm_drop_path": 0.1,
"image_size": 320
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 1024,
"heads": 16,
"layers": 24
}
}
================================================
FILE: inf_clip/model_configs/mt5-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "google/mt5-base",
"hf_tokenizer_name": "google/mt5-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/mt5-xl-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "google/mt5-xl",
"hf_tokenizer_name": "google/mt5-xl",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-base-siglip.json
================================================
{
"embed_dim": 768,
"custom_text": true,
"init_logit_bias": -10,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_base_patch16_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-600M",
"hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-base.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-600M",
"hf_tokenizer_name": "facebook/nllb-200-distilled-600M",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-large-siglip.json
================================================
{
"embed_dim": 1152,
"custom_text": true,
"init_logit_bias": -10,
"vision_cfg": {
"image_size": 384,
"timm_model_name": "vit_so400m_patch14_siglip_384",
"timm_model_pretrained": false,
"timm_pool": "map",
"timm_proj": "none"
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-1.3B",
"hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/nllb-clip-large.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "facebook/nllb-200-distilled-1.3B",
"hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B",
"hf_proj_type": "linear",
"hf_pooler_type": "cls_pooler"
}
}
================================================
FILE: inf_clip/model_configs/roberta-ViT-B-32.json
================================================
{
"embed_dim": 512,
"quick_gelu": true,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "roberta-base",
"hf_tokenizer_name": "roberta-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/swin_base_patch4_window7_224.json
================================================
{
"embed_dim": 640,
"vision_cfg": {
"timm_model_name": "swin_base_patch4_window7_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 640,
"heads": 10,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/vit_medium_patch16_gap_256.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_medium_patch16_gap_256",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 256
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"timm_model_name": "vit_relpos_medium_patch16_cls_224",
"timm_model_pretrained": false,
"timm_pool": "",
"timm_proj": "linear",
"image_size": 224
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
}
}
================================================
FILE: inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json
================================================
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"hf_model_name": "xlm-roberta-base",
"hf_tokenizer_name": "xlm-roberta-base",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json
================================================
{
"embed_dim": 1024,
"vision_cfg": {
"image_size": 224,
"layers": 32,
"width": 1280,
"head_width": 80,
"patch_size": 14
},
"text_cfg": {
"hf_model_name": "xlm-roberta-large",
"hf_tokenizer_name": "xlm-roberta-large",
"hf_pooler_type": "mean_pooler"
}
}
================================================
FILE: inf_clip/models/clip_arch.py
================================================
""" CLIP Model
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
from functools import partial
from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
text_global_pool
from ..utils import to_2tuple
@dataclass
class CLIPVisionCfg:
layers: Union[Tuple[int, int, int, int], int] = 12
width: int = 768
head_width: int = 64
mlp_ratio: float = 4.0
patch_size: int = 16
image_size: Union[Tuple[int, int], int] = 224
ls_init_value: Optional[float] = None # layer scale initial value
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)
attn_pooler_queries: int = 256 # n_queries for attentional pooler
attn_pooler_heads: int = 8 # n heads for attentional_pooling
no_ln_pre: bool = False # disable pre transformer LayerNorm
pos_embed_type: str = 'learnable'
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
pool_type: str = 'tok'
output_tokens: bool = False
act_kwargs: Optional[dict] = None
norm_kwargs: Optional[dict] = None
timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias for the final projection
timm_drop: float = 0. # head dropout
timm_drop_path: Optional[float] = None # backbone stochastic depth
@dataclass
class CLIPTextCfg:
context_length: int = 77
vocab_size: int = 49408
hf_tokenizer_name: Optional[str] = None
tokenizer_kwargs: Optional[dict] = None
width: int = 512
heads: int = 8
layers: int = 12
mlp_ratio: float = 4.0
ls_init_value: Optional[float] = None # layer scale initial value
embed_cls: bool = False
pad_id: int = 0
no_causal_mask: bool = False # disable causal masking
final_ln_after_pool: bool = False # apply final LayerNorm after pooling
pool_type: str = 'argmax'
proj_bias: bool = False
output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None
# HuggingFace specific text tower config
hf_model_name: Optional[str] = None
hf_model_pretrained: bool = True
hf_proj_type: str = 'mlp'
    hf_pooler_type: str = 'mean_pooler'  # pooling used for the HF text tower (e.g. 'mean_pooler', 'cls_pooler')
def get_cast_dtype(precision: str):
cast_dtype = None
if precision == 'bf16':
cast_dtype = torch.bfloat16
elif precision == 'fp16':
cast_dtype = torch.float16
return cast_dtype
def get_input_dtype(precision: str):
input_dtype = None
if precision in ('bf16', 'pure_bf16'):
input_dtype = torch.bfloat16
elif precision in ('fp16', 'pure_fp16'):
input_dtype = torch.float16
return input_dtype
def _build_vision_tower(
embed_dim: int,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
if isinstance(vision_cfg, dict):
vision_cfg = CLIPVisionCfg(**vision_cfg)
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
# memory efficient in recent PyTorch releases (>= 1.10).
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
act_layer = QuickGELU if quick_gelu else nn.GELU
if vision_cfg.timm_model_name:
visual = TimmModel(
vision_cfg.timm_model_name,
pretrained=vision_cfg.timm_model_pretrained,
pool=vision_cfg.timm_pool,
proj=vision_cfg.timm_proj,
proj_bias=vision_cfg.timm_proj_bias,
drop=vision_cfg.timm_drop,
drop_path=vision_cfg.timm_drop_path,
patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,
embed_dim=embed_dim,
image_size=vision_cfg.image_size,
)
elif isinstance(vision_cfg.layers, (tuple, list)):
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
visual = ModifiedResNet(
layers=vision_cfg.layers,
output_dim=embed_dim,
heads=vision_heads,
image_size=vision_cfg.image_size,
width=vision_cfg.width,
)
else:
vision_heads = vision_cfg.width // vision_cfg.head_width
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if vision_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)
if vision_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **vision_cfg.act_kwargs)
visual = VisionTransformer(
image_size=vision_cfg.image_size,
patch_size=vision_cfg.patch_size,
width=vision_cfg.width,
layers=vision_cfg.layers,
heads=vision_heads,
mlp_ratio=vision_cfg.mlp_ratio,
ls_init_value=vision_cfg.ls_init_value,
patch_dropout=vision_cfg.patch_dropout,
attentional_pool=vision_cfg.attentional_pool,
attn_pooler_queries=vision_cfg.attn_pooler_queries,
attn_pooler_heads=vision_cfg.attn_pooler_heads,
pos_embed_type=vision_cfg.pos_embed_type,
no_ln_pre=vision_cfg.no_ln_pre,
final_ln_after_pool=vision_cfg.final_ln_after_pool,
pool_type=vision_cfg.pool_type,
output_tokens=vision_cfg.output_tokens,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return visual
def _build_text_tower(
embed_dim: int,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
if isinstance(text_cfg, dict):
text_cfg = CLIPTextCfg(**text_cfg)
if text_cfg.hf_model_name:
text = HFTextEncoder(
text_cfg.hf_model_name,
output_dim=embed_dim,
proj_type=text_cfg.hf_proj_type,
pooler_type=text_cfg.hf_pooler_type,
pretrained=text_cfg.hf_model_pretrained,
output_tokens=text_cfg.output_tokens,
)
else:
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
if text_cfg.norm_kwargs:
norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)
if text_cfg.act_kwargs is not None:
act_layer = partial(act_layer, **text_cfg.act_kwargs)
text = TextTransformer(
context_length=text_cfg.context_length,
vocab_size=text_cfg.vocab_size,
width=text_cfg.width,
heads=text_cfg.heads,
layers=text_cfg.layers,
mlp_ratio=text_cfg.mlp_ratio,
ls_init_value=text_cfg.ls_init_value,
output_dim=embed_dim,
embed_cls=text_cfg.embed_cls,
no_causal_mask=text_cfg.no_causal_mask,
pad_id=text_cfg.pad_id,
pool_type=text_cfg.pool_type,
proj_bias=text_cfg.proj_bias,
output_tokens=text_cfg.output_tokens,
act_layer=act_layer,
norm_layer=norm_layer,
)
return text
class CLIP(nn.Module):
output_dict: torch.jit.Final[bool]
arch_type: torch.jit.Final[str] = 'clip'
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.transformer = text.transformer
self.context_length = text.context_length
self.vocab_size = text.vocab_size
self.token_embedding = text.token_embedding
self.positional_embedding = text.positional_embedding
self.ln_final = text.ln_final
self.text_projection = text.text_projection
self.text_pool_type = text.pool_type
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.transformer.grad_checkpointing = enable
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
cast_dtype = self.transformer.get_cast_dtype()
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.to(cast_dtype)
x = self.transformer(x, attn_mask=self.attn_mask)
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
x, _ = text_global_pool(x, text, self.text_pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
x = self.text_projection(x)
else:
x = x @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict
if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()
class CustomTextCLIP(nn.Module):
output_dict: torch.jit.Final[bool]
arch_type: torch.jit.Final[str] = 'clip'
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
):
super().__init__()
self.output_dict = output_dict
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
self.context_length = self.text.context_length
self.vocab_size = self.text.vocab_size
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
self.text.lock(unlocked_layers, freeze_layer_norm)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
def encode_image(self, image, normalize: bool = False):
features = self.visual(image)
return F.normalize(features, dim=-1) if normalize else features
def encode_text(self, text, normalize: bool = False):
features = self.text(text)
return F.normalize(features, dim=-1) if normalize else features
def get_logits(self, image, text):
image_features = self.encode_image(image, normalize=True)
text_features = self.encode_text(text, normalize=True)
image_logits = self.logit_scale.exp() * image_features @ text_features.T
if self.logit_bias is not None:
image_logits += self.logit_bias
text_logits = image_logits.T
return image_logits, text_logits
def forward(
self,
image: Optional[torch.Tensor] = None,
text: Optional[torch.Tensor] = None,
):
image_features = self.encode_image(image, normalize=True) if image is not None else None
text_features = self.encode_text(text, normalize=True) if text is not None else None
if self.output_dict:
out_dict = {
"image_features": image_features,
"text_features": text_features,
"logit_scale": self.logit_scale.exp()
}
if self.logit_bias is not None:
out_dict['logit_bias'] = self.logit_bias
return out_dict
if self.logit_bias is not None:
return image_features, text_features, self.logit_scale.exp(), self.logit_bias
return image_features, text_features, self.logit_scale.exp()
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
def _convert_weights(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype)
if isinstance(l, (nn.MultiheadAttention, Attention)):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.to(dtype)
if isinstance(l, (CLIP, TextTransformer)):
# convert text nn.Parameter projections
attr = getattr(l, "text_projection", None)
if attr is not None:
attr.data = attr.data.to(dtype)
if isinstance(l, VisionTransformer):
# convert vision nn.Parameter projections
attr = getattr(l, "proj", None)
if attr is not None:
attr.data = attr.data.to(dtype)
model.apply(_convert_weights)
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
if 'text_projection' in state_dict:
# old format state_dict, move text tower -> .text
new_state_dict = {}
for k, v in state_dict.items():
if any(k.startswith(p) for p in (
'text_projection',
'positional_embedding',
'token_embedding',
'transformer',
'ln_final',
)):
k = 'text.' + k
new_state_dict[k] = v
return new_state_dict
return state_dict
def build_model_from_openai_state_dict(
state_dict: dict,
quick_gelu=True,
cast_dtype=torch.float16,
):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_size = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_size = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
vision_cfg = CLIPVisionCfg(
layers=vision_layers,
width=vision_width,
patch_size=vision_patch_size,
image_size=image_size,
)
text_cfg = CLIPTextCfg(
context_length=context_length,
vocab_size=vocab_size,
width=transformer_width,
heads=transformer_heads,
layers=transformer_layers,
)
model = CLIP(
embed_dim,
vision_cfg=vision_cfg,
text_cfg=text_cfg,
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
cast_dtype=cast_dtype,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
state_dict.pop(key, None)
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
model.load_state_dict(state_dict)
return model.eval()
def trace_model(model, batch_size=256, device=torch.device('cpu')):
model.eval()
image_size = model.visual.image_size
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
model = torch.jit.trace_module(
model,
inputs=dict(
forward=(example_images, example_text),
encode_text=(example_text,),
encode_image=(example_images,)
))
model.visual.image_size = image_size
return model
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
# Rescale the grid of position embeddings when loading from state_dict
old_pos_embed = state_dict.get('visual.positional_embedding', None)
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
return
grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (i.e. no class token, or more)
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
if new_seq_len == old_pos_embed.shape[0]:
return
if extra_tokens:
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
else:
pos_emb_tok, pos_emb_img = None, old_pos_embed
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
pos_emb_img = F.interpolate(
pos_emb_img,
size=grid_size,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
if pos_emb_tok is not None:
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
else:
new_pos_embed = pos_emb_img
state_dict['visual.positional_embedding'] = new_pos_embed
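# NOTE: illustrative example, not part of the upstream file. Loading a 224px
# ViT-B/32 checkpoint (7x7 grid, 49 + 1 = 50 positional tokens) into a 448px
# model (14x14 grid) interpolates the 49 image tokens up to 196 and re-attaches
# the class-token embedding, yielding a new 197-token 'visual.positional_embedding'.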
def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):
old_pos_embed = state_dict.get('positional_embedding', None)
if old_pos_embed is None:
return
# FIXME add support for text cls_token
model_pos_embed = getattr(model, 'positional_embedding', None)
if model_pos_embed is None:
model_pos_embed = getattr(model.text, 'positional_embedding', None)
old_num_pos = old_pos_embed.shape[0]
old_width = old_pos_embed.shape[1]
num_pos = model_pos_embed.shape[0]
width = model_pos_embed.shape[1]
assert old_width == width, 'text pos_embed width changed!'
if old_num_pos == num_pos:
return
logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)
old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)
old_pos_embed = F.interpolate(
old_pos_embed,
size=num_pos,
mode=interpolation,
antialias=antialias,
align_corners=False,
)
old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]
new_pos_embed = old_pos_embed
state_dict['positional_embedding'] = new_pos_embed
def get_model_preprocess_cfg(model):
module = getattr(model, 'visual', model)
preprocess_cfg = getattr(module, 'preprocess_cfg', {})
if not preprocess_cfg:
# use separate legacy attributes if preprocess_cfg dict not found
size = getattr(module, 'image_size')
if size is not None:
preprocess_cfg['size'] = size
mean = getattr(module, 'image_mean', None)
if mean is not None:
preprocess_cfg['mean'] = mean
std = getattr(module, 'image_std', None)
if std is not None:
preprocess_cfg['std'] = std
return preprocess_cfg
def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):
module = getattr(model, 'visual', model)
module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat
module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat
module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict
def get_model_tokenize_cfg(model):
module = getattr(model, 'text', model)
cfg = {}
context_length = getattr(module, 'context_length', None)
if context_length is not None:
cfg['context_length'] = context_length
vocab_size = getattr(module, 'vocab_size', None)
if vocab_size is not None:
cfg['vocab_size'] = vocab_size
return cfg
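# NOTE: minimal usage sketch (illustrative only, not part of the upstream file).
# _build_vision_tower/_build_text_tower coerce plain dicts into the dataclasses
# above, so a model can be built directly from one of the JSON configs in
# model_configs/ (values below are the standard ViT-B-32 settings):
#
# model = CLIP(
#     embed_dim=512,
#     vision_cfg={"image_size": 224, "layers": 12, "width": 768, "patch_size": 32},
#     text_cfg={"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12},
# )
# images = torch.randn(4, 3, 224, 224)
# texts = torch.randint(0, 49408, (4, 77))
# image_logits, text_logits = model.get_logits(images, texts)  # both (4, 4)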
================================================
FILE: inf_clip/models/coca_arch.py
================================================
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
)
from .clip_arch import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower
try:
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StopStringCriteria,
EosTokenCriteria,
StoppingCriteriaList
)
GENERATION_TYPES = {
"top_k": TopKLogitsWarper,
"top_p": TopPLogitsWarper,
"beam_search": "beam_search"
}
_has_transformers = True
except ImportError as e:
GENERATION_TYPES = {
"top_k": None,
"top_p": None,
"beam_search": "beam_search"
}
_has_transformers = False
@dataclass
class MultimodalCfg(CLIPTextCfg):
mlp_ratio: int = 4
dim_head: int = 64
heads: int = 8
n_queries: int = 256
attn_pooler_heads: int = 8
def _build_text_decoder_tower(
embed_dim,
multimodal_cfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
):
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
act_layer = QuickGELU if quick_gelu else nn.GELU
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
decoder = MultimodalTransformer(
context_length=multimodal_cfg.context_length,
width=multimodal_cfg.width,
heads=multimodal_cfg.heads,
layers=multimodal_cfg.layers,
ls_init_value=multimodal_cfg.ls_init_value,
output_dim=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer,
)
return decoder
def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor:
if not isinstance(token_id, torch.Tensor):
if isinstance(token_id, int):
token_id = [token_id]
token_id = torch.tensor(token_id, device=device)
return token_id
class CoCa(nn.Module):
arch_type: torch.jit.Final[str] = 'coca'
def __init__(
self,
embed_dim,
multimodal_cfg: MultimodalCfg,
text_cfg: CLIPTextCfg,
vision_cfg: CLIPVisionCfg,
quick_gelu: bool = False,
init_logit_scale: float = np.log(1 / 0.07),
init_logit_bias: Optional[float] = None,
cast_dtype: Optional[torch.dtype] = None,
pad_id: int = 0,
):
super().__init__()
multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg
self.text = _build_text_tower(
embed_dim=embed_dim,
text_cfg=text_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
vocab_size = (
text_cfg.vocab_size # for hf models
if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
else text_cfg.vocab_size
)
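        # NOTE: both branches above resolve to text_cfg.vocab_size (mirrors upstream open_clip).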
self.visual = _build_vision_tower(
embed_dim=embed_dim,
vision_cfg=vision_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.text_decoder = _build_text_decoder_tower(
vocab_size,
multimodal_cfg=multimodal_cfg,
quick_gelu=quick_gelu,
cast_dtype=cast_dtype,
)
self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)
if init_logit_bias is not None:
self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)
else:
self.logit_bias = None
self.pad_id = pad_id
self.context_length = multimodal_cfg.context_length
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True):
self.visual.set_grad_checkpointing(enable)
self.text.set_grad_checkpointing(enable)
self.text_decoder.set_grad_checkpointing(enable)
def _encode_image(self, images, normalize: bool = True):
image_latent, tokens_embs = self.visual(images)
image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
return image_latent, tokens_embs
def _encode_text(self, text, normalize: bool = True):
text_latent, token_emb = self.text(text)
text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
return text_latent, token_emb
def encode_image(self, images, normalize: bool = True):
image_latent, _ = self._encode_image(images, normalize=normalize)
return image_latent
def encode_text(self, text, normalize: bool = True):
text_latent, _ = self._encode_text(text, normalize=normalize)
return text_latent
def forward(
self,
image,
text: Optional[torch.Tensor] = None,
image_latent: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
output_labels: bool = True,
):
if image_latent is None or image_embs is None:
image_latent, image_embs = self._encode_image(image)
if text is None:
return {"image_features": image_latent, "image_embs": image_embs}
text_latent, token_embs = self._encode_text(text)
# FIXME this isn't an ideal solution, would like to improve -RW
labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
if output_labels:
# align text_embs and thus logits with labels for teacher-forcing caption loss
token_embs = token_embs[:, :-1]
logits = self.text_decoder(image_embs, token_embs)
out_dict = {
"image_features": image_latent,
"text_features": text_latent,
"logits": logits,
"logit_scale": self.logit_scale.exp()
}
if labels is not None:
out_dict["labels"] = labels
if self.logit_bias is not None:
out_dict["logit_bias"] = self.logit_bias
return out_dict
def generate(
self,
image,
text=None,
seq_len=30,
max_seq_len=77,
temperature=1.,
generation_type="beam_search",
top_p=0.1, # keep tokens in the 1 - top_p quantile
top_k=1, # keeps the top_k most probable tokens
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
repetition_penalty=1.0,
fixed_output_length=False # if True output.shape == (batch_size, seq_len)
):
# taking many ideas and components from HuggingFace GenerationMixin
# https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
device = image.device
with torch.no_grad():
sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)
eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)
pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
stopping_criteria = StoppingCriteriaList(stopping_criteria)
if generation_type == "beam_search":
output = self._generate_beamsearch(
image_inputs=image,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
sot_token_id=sot_token_id,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
min_seq_len=min_seq_len,
stopping_criteria=stopping_criteria,
logit_processor=logit_processor,
)
if fixed_output_length and output.shape[1] < seq_len:
pad_len = seq_len - output.shape[1]
return torch.cat((
output,
torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
),
dim=1
)
return output
elif generation_type == "top_p":
logit_warper = GENERATION_TYPES[generation_type](top_p)
elif generation_type == "top_k":
logit_warper = GENERATION_TYPES[generation_type](top_k)
else:
raise ValueError(
f"generation_type has to be one of "
f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
)
image_latent, image_embs = self._encode_image(image)
if text is None:
text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id
was_training = self.training
num_dims = len(text.shape)
if num_dims == 1:
text = text[None, :]
self.eval()
out = text
while True:
x = out[:, -max_seq_len:]
cur_len = x.shape[1]
logits = self(
image,
x,
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)["logits"][:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id
if mask.all():
if not fixed_output_length:
break
else:
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
probs = F.softmax(filtered_logits / temperature, dim=-1)
if (cur_len + 1 == seq_len):
sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
else:
sample[~mask, :] = torch.multinomial(probs, 1)
out = torch.cat((out, sample), dim=-1)
cur_len += 1
if all(stopping_criteria(out, None)):
break
if num_dims == 1:
out = out.squeeze(0)
self.train(was_training)
return out
def _generate_beamsearch(
self,
image_inputs,
pad_token_id=None,
eos_token_id=None,
sot_token_id=None,
num_beams=6,
num_beam_groups=3,
min_seq_len=5,
stopping_criteria=None,
logit_processor=None,
logit_warper=None,
):
device = image_inputs.device
batch_size = image_inputs.shape[0]
image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
image_latent, image_embs = self._encode_image(image_inputs)
input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
input_ids = input_ids * sot_token_id
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=device,
num_beam_groups=num_beam_groups,
)
# instantiate logits processors
logits_processor = (
LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
if logit_processor is None
else logit_processor
)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
batch_beam_size, cur_len = input_ids.shape
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# predicted tokens in cur_len step
current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
# indices which will form the beams in the next time step
reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)
# do one decoder step on all beams of all sentences in batch
model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
outputs = self(
model_inputs['images'],
model_inputs['text'],
image_latent=image_latent,
image_embs=image_embs,
output_labels=False,
)
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
group_size = group_end_idx - group_start_idx
# indices of beams of current group among all sentences in batch
batch_group_indices = []
for batch_idx in range(batch_size):
batch_group_indices.extend(
[batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
)
group_input_ids = input_ids[batch_group_indices]
                # select outputs of beams of current group only
next_token_logits = outputs['logits'][batch_group_indices, -1, :]
vocab_size = next_token_logits.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
next_token_scores, next_tokens = torch.topk(
next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
# stateless
process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
beam_outputs = beam_scorer.process(
group_input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=process_beam_indices,
group_index=beam_group_idx,
)
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
# (beam_idx // group_size) -> batch_idx
# (beam_idx % group_size) -> offset of idx inside the group
reordering_indices[batch_group_indices] = (
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
)
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
break
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=final_beam_indices,
)
return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
else:
position_ids = None
return {
"text": input_ids,
"images": image_inputs,
"past_key_values": past,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
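# NOTE: usage sketch (illustrative only, not part of the upstream file). Given a
# CoCa instance `model` and an image batch, captions are decoded with generate();
# 49406/49407 are the default SOT/EOS ids of the CLIP tokenizer:
#
# token_ids = model.generate(images, generation_type="beam_search", num_beams=6)
# token_ids = model.generate(images, generation_type="top_k", top_k=5, seq_len=20)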
================================================
FILE: inf_clip/models/hf_configs.py
================================================
# HF architecture dict:
arch_dict = {
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
"roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
"xlm-roberta": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
"layer_attr": "layer",
"token_embeddings_attr": "embeddings"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
"mt5": {
"config_names": {
# unlimited seqlen
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
"context_length": "",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "num_heads",
"layers": "num_layers",
"layer_attr": "block",
"token_embeddings_attr": "embed_tokens"
},
"pooler": "mean_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/bert
"bert": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "hidden_size",
"heads": "num_attention_heads",
"layers": "num_hidden_layers",
},
"pooler": "cls_pooler",
},
# https://huggingface.co/docs/transformers/model_doc/m2m_100
"m2m_100": {
"config_names": {
"context_length": "max_position_embeddings",
"vocab_size": "vocab_size",
"width": "d_model",
"heads": "encoder_attention_heads",
"layers": "encoder_layers",
},
"pooler": "cls_pooler",
},
}
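# NOTE: usage sketch (illustrative only, not part of the upstream file). The dict
# above lets the HF text tower read hyper-parameters from heterogeneous HF
# configs through a uniform key, e.g.:
#
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained("xlm-roberta-base")
# key = arch_dict[config.model_type]["config_names"]["width"]  # "hidden_size"
# width = getattr(config, key)  # 768 for xlm-roberta-base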
================================================
FILE: inf_clip/models/hf_model.py
================================================
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP models.
"""
import re
from contextlib import nullcontext
import torch
import torch.nn as nn
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import no_init_weights
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# NOTE: the remainder of hf_model.py (the pooler classes and HFTextEncoder) was
# lost in extraction.
================================================
FILE: inf_clip/models/loss.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE: reconstructed file head. The original imports and the definitions of
# gather_features, all_reduce, _cal_flash_loss, cal_ring_loss, and cal_inf_loss
# (the fused losses come from this repo's inf_cl package) were lost in
# extraction; the ClipLoss signature below is inferred from the sibling loss
# classes in this file and from the CoCaLoss super().__init__ call.
class ClipLoss(nn.Module):
    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        # cache state
        self.prev_num_logits = 0
        self.labels = {}
    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
if self.world_size > 1 and self.local_loss:
labels = labels + num_logits * self.rank
if self.cache_labels:
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def get_logits(self, image_features, text_features, logit_scale):
if self.world_size > 1:
all_image_features, all_text_features = gather_features(
image_features, text_features,
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
if self.local_loss:
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
else:
logits_per_image = logit_scale * all_image_features @ all_text_features.T
logits_per_text = logits_per_image.T
else:
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T
return logits_per_image, logits_per_text
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
labels = self.get_ground_truth(device, logits_per_image.shape[0])
total_loss = (
F.cross_entropy(logits_per_image, labels, reduction='none') +
F.cross_entropy(logits_per_text, labels, reduction='none')
) / 2
scale_factor = (total_loss.shape[0] / image_features.shape[0])
total_loss = torch.mean(total_loss * scale_factor)
show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)
return {"contrastive_loss": total_loss, "show_loss": show_loss}
class DiscoClipLoss(nn.Module):
def __init__(
self,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.prev_num_logits = 0
self.labels = {}
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
labels = labels + num_logits * self.rank
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
all_image_features, all_text_features = gather_features(
image_features, text_features,
True, True, self.rank, self.world_size, self.use_horovod)
logits_per_image = logit_scale * image_features @ all_text_features.T
logits_per_text = logit_scale * text_features @ all_image_features.T
labels = self.get_ground_truth(device, logits_per_image.shape[0])
total_loss = (
F.cross_entropy(logits_per_image, labels, reduction='none') +
F.cross_entropy(logits_per_text, labels, reduction='none')
) / 2
total_loss = torch.mean(total_loss)
show_loss = all_reduce(total_loss.detach().clone()) / self.world_size
return {"contrastive_loss": total_loss, "show_loss": show_loss}
class FlashClipLoss(nn.Module):
def __init__(
self,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.prev_num_logits = 0
self.labels = {}
def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground-truth and cache if enabled
if self.prev_num_logits != num_logits or device not in self.labels:
labels = torch.arange(num_logits, device=device, dtype=torch.long)
self.labels[device] = labels
self.prev_num_logits = num_logits
else:
labels = self.labels[device]
return labels
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
all_image_features, all_text_features = gather_features(
image_features, text_features,
False, False, self.rank, self.world_size, self.use_horovod)
labels = self.get_ground_truth(device, all_image_features.shape[0])
i2t_loss = _cal_flash_loss(logit_scale * all_image_features, all_text_features, labels)
t2i_loss = _cal_flash_loss(logit_scale * all_text_features, all_image_features, labels)
total_loss = (i2t_loss + t2i_loss) / 2
scale_factor = (total_loss.shape[0] / image_features.shape[0])
total_loss = torch.mean(total_loss * scale_factor)
show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)
return {"contrastive_loss": total_loss, "show_loss": show_loss}
class RingClipLoss(nn.Module):
def __init__(
self,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.labels = {}
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
q = image_features
k = text_features
l = logit_scale
if device in self.labels:
labels = self.labels[device]
else:
labels = torch.arange(q.shape[0], device=device, dtype=torch.long)
self.labels[device] = labels
i2t_loss = cal_ring_loss(q, k, labels, l)
t2i_loss = cal_ring_loss(k, q, labels, l)
total_loss = (i2t_loss + t2i_loss) / 2
total_loss = total_loss.mean()
show_loss = all_reduce(total_loss.detach().clone()) / self.world_size
return {"contrastive_loss": total_loss, "show_loss": show_loss}
class InfClipLoss(nn.Module):
def __init__(
self,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__()
self.rank = rank
self.world_size = world_size
self.use_horovod = use_horovod
# cache state
self.labels = {}
def forward(self, image_features, text_features, logit_scale):
device = image_features.device
q = image_features
k = text_features
l = logit_scale
if device in self.labels:
labels = self.labels[device]
else:
labels = torch.arange(q.shape[0], device=device, dtype=torch.long)
self.labels[device] = labels
i2t_loss = cal_inf_loss(q, k, labels, l)
t2i_loss = cal_inf_loss(k, q, labels, l)
total_loss = (i2t_loss + t2i_loss) / 2
total_loss = total_loss.mean()
show_loss = all_reduce(total_loss.detach().clone()) / self.world_size
return {"contrastive_loss": total_loss, "show_loss": show_loss}
        # NOTE: debug code for checking ring loss gradient
# rank = dist.get_rank()
# q = image_features.detach().clone().requires_grad_()
# k = text_features.detach().clone().requires_grad_()
# l = logit_scale.detach().clone().requires_grad_()
# all_q, all_k = gather_features(
# q, k,
# self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
# if rank == 0:
# import numpy as np
# np.save('all_q.npy', all_q.detach().cpu().numpy())
# np.save('all_k.npy', all_k.detach().cpu().numpy())
# labels = torch.arange(all_q.shape[0], device=device, dtype=torch.long)
# qk = l * all_q @ all_k.T
# lse = torch.logsumexp(qk, dim=1)
# # numerator = torch.einsum("md,md->m",l * all_q, all_k[labels])
# numerator = qk[torch.arange(qk.shape[0]), labels]
# i2t_loss = -numerator + lse
# lse = torch.logsumexp(qk.T, dim=1)
# # numerator = torch.einsum("md,md->m", l * all_k, all_q[labels])
# numerator = qk.T[labels, torch.arange(qk.shape[0])]
# t2i_loss = -numerator + lse
# # i2t_loss = F.cross_entropy(qk, labels, reduction='none')
# # t2i_loss = F.cross_entropy(qk.T, labels, reduction='none')
# total_loss = (i2t_loss + t2i_loss) / 2
# total_loss.sum().backward()
# q1 = image_features.detach().clone().requires_grad_()
# k1 = text_features.detach().clone().requires_grad_()
# l1 = logit_scale.detach().clone().requires_grad_()
# labels = torch.arange(q1.shape[0], device=device, dtype=torch.long)
# i2t_loss1 = _cal_ring_loss(l1 * q1, k1, labels)
# t2i_loss1 = _cal_ring_loss(l1 * k1, q1, labels)
# total_loss1 = (i2t_loss1 + t2i_loss1) / 2
# q1.retain_grad(); k1.retain_grad(); l1.retain_grad()
# total_loss1.sum().backward()
# if rank == 1:
# import numpy as np
# np.save('q_r1_grad.npy', q.grad.detach().cpu().numpy())
# np.save('q1_r1_grad.npy', q1.grad.detach().cpu().numpy())
# print(q.grad, q1.grad)
# print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)), torch.max(torch.abs(l.grad - l1.grad)))
# exit(0)
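# NOTE: sketch (illustrative only). cal_inf_loss and cal_ring_loss come from
# this repo's inf_cl package and compute the contrastive cross-entropy without
# materializing the full global similarity matrix; per rank the call mirrors
# InfClipLoss.forward above:
#
# labels = torch.arange(image_features.shape[0], device=image_features.device)
# loss = (cal_inf_loss(image_features, text_features, labels, logit_scale) +
#         cal_inf_loss(text_features, image_features, labels, logit_scale)) / 2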
class CoCaLoss(ClipLoss):
def __init__(
self,
caption_loss_weight,
clip_loss_weight,
pad_id=0, # pad_token for open_clip custom tokenizer
local_loss=False,
gather_with_grad=False,
cache_labels=False,
rank=0,
world_size=1,
use_horovod=False,
):
super().__init__(
local_loss=local_loss,
gather_with_grad=gather_with_grad,
cache_labels=cache_labels,
rank=rank,
world_size=world_size,
use_horovod=use_horovod
)
self.clip_loss_weight = clip_loss_weight
self.caption_loss_weight = caption_loss_weight
self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)
def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
clip_loss = torch.tensor(0)
if self.clip_loss_weight:
clip_loss = super().forward(image_features, text_features, logit_scale)
clip_loss = self.clip_loss_weight * clip_loss
caption_loss = self.caption_loss(
logits.permute(0, 2, 1),
labels,
)
caption_loss = caption_loss * self.caption_loss_weight
if output_dict:
return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}
return clip_loss, caption_loss
class DistillClipLoss(ClipLoss):
def dist_loss(self, teacher_logits, student_logits):
return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)
def forward(
self,
image_features,
text_features,
logit_scale,
dist_image_features,
dist_text_features,
dist_logit_scale,
output_dict=False,
):
logits_per_image, logits_per_text = \
self.get_logits(image_features, text_features, logit_scale)
dist_logits_per_image, dist_logits_per_text = \
self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)
labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])
contrastive_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
distill_loss = (
self.dist_loss(dist_logits_per_image, logits_per_image) +
self.dist_loss(dist_logits_per_text, logits_per_text)
) / 2
if output_dict:
return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}
return contrastive_loss, distill_loss
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
tensor_recv = torch.zeros_like(tensor)
send_op = torch.distributed.P2POp(
torch.distributed.isend,
tensor,
to_rank,
group=group,
)
recv_op = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_recv,
from_rank,
group=group,
)
reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
return tensor_recv
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
tensor_from_left = torch.zeros_like(tensor_to_right)
tensor_from_right = torch.zeros_like(tensor_to_left)
send_op_left = torch.distributed.P2POp(
torch.distributed.isend,
tensor_to_left,
left_rank,
group=group,
)
send_op_right = torch.distributed.P2POp(
torch.distributed.isend,
tensor_to_right,
right_rank,
group=group,
)
recv_op_left = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_from_left,
left_rank,
group=group,
)
recv_op_right = torch.distributed.P2POp(
torch.distributed.irecv,
tensor_from_right,
right_rank,
group=group,
)
reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])
for req in reqs:
req.wait()
return tensor_from_right, tensor_from_left
class NeighbourExchange(torch.autograd.Function):
@staticmethod
def forward(ctx, from_rank, to_rank, group, tensor):
ctx.group = group
ctx.from_rank = from_rank
ctx.to_rank = to_rank
return neighbour_exchange(from_rank, to_rank, tensor, group=group)
@staticmethod
def backward(ctx, grad_output):
return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)
def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
class NeighbourExchangeBidir(torch.autograd.Function):
@staticmethod
def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
ctx.group = group
ctx.left_rank = left_rank
ctx.right_rank = right_rank
return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)
@staticmethod
def backward(ctx, *grad_outputs):
return (None, None, None) + \
NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)
def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)
class SigLipLoss(nn.Module):
""" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343
@article{zhai2023sigmoid,
title={Sigmoid loss for language image pre-training},
author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},
journal={arXiv preprint arXiv:2303.15343},
year={2023}
}
"""
def __init__(
self,
cache_labels=False,
rank=0,
world_size=1,
bidir=True,
use_horovod=False,
):
super().__init__()
self.cache_labels = cache_labels
self.rank = rank
self.world_size = world_size
assert not use_horovod # FIXME need to look at hvd ops for ring transfers
self.use_horovod = use_horovod
self.bidir = bidir
# cache state FIXME cache not currently used, worthwhile?
self.prev_num_logits = 0
self.labels = {}
def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:
labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)
if not negative_only:
labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels
return labels
def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):
logits = logit_scale * image_features @ text_features.T
if logit_bias is not None:
logits += logit_bias
return logits
def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):
logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)
labels = self.get_ground_truth(
image_features.device,
image_features.dtype,
image_features.shape[0],
negative_only=negative_only,
)
loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]
return loss
def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):
loss = self._loss(image_features, text_features, logit_scale, logit_bias)
if self.world_size > 1:
# exchange text features w/ neighbour world_size - 1 times
right_rank = (self.rank + 1) % self.world_size
left_rank = (self.rank - 1 + self.world_size) % self.world_size
if self.bidir:
text_features_to_right = text_features_to_left = text_features
num_bidir, remainder = divmod(self.world_size - 1, 2)
for i in range(num_bidir):
text_features_recv = neighbour_exchange_bidir_with_grad(
left_rank,
right_rank,
text_features_to_left,
text_features_to_right,
)
for f in text_features_recv:
loss += self._loss(
image_features,
f,
logit_scale,
logit_bias,
negative_only=True,
)
text_features_to_left, text_features_to_right = text_features_recv
if remainder:
text_features_recv = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)
loss += self._loss(
image_features,
text_features_recv,
logit_scale,
logit_bias,
negative_only=True,
)
else:
text_features_to_right = text_features
for i in range(self.world_size - 1):
text_features_from_left = neighbour_exchange_with_grad(
left_rank, right_rank, text_features_to_right)
loss += self._loss(
image_features,
text_features_from_left,
logit_scale,
logit_bias,
negative_only=True,
)
text_features_to_right = text_features_from_left
return {"contrastive_loss": loss} if output_dict else loss
================================================
FILE: inf_clip/models/modified_resnet.py
================================================
from collections import OrderedDict
import torch
from torch import nn
from torch.nn import functional as F
from ..utils import freeze_batch_norm_2d
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.act2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.act3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.act1(self.bn1(self.conv1(x)))
out = self.act2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.act3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0.,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
super().__init__()
self.output_dim = output_dim
self.image_size = image_size
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.act3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
self.init_parameters()
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def init_parameters(self):
if self.attnpool is not None:
std = self.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
for param in self.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
# FIXME support for non-transformer
pass
def stem(self, x):
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x):
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
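# Minimal shape-check sketch (illustrative, not part of the training code),
# assuming the RN50 configuration: layers=(3, 4, 6, 3), width=64, so the
# final feature dim is width * 32 = 2048 and a 224px input gives the
# attention pool a 7x7 grid (224 // 32 = 7).
if __name__ == "__main__":
    model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32, image_size=224, width=64)
    features = model(torch.randn(2, 3, 224, 224))
    print(features.shape)  # expected: torch.Size([2, 1024])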
================================================
FILE: inf_clip/models/pos_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
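def _demo_sincos_shapes():
    # Illustrative sketch, assuming a ViT-B/16-style setup (14x14 grid,
    # width 768): the helper returns one row per patch position, plus an
    # all-zero leading row for the class token when cls_token=True.
    pos_embed = get_2d_sincos_pos_embed(768, 14, cls_token=True)
    assert pos_embed.shape == (14 * 14 + 1, 768)
    assert not pos_embed[0].any()  # class-token row is zeros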
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
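# Minimal sketch of interpolate_pos_embed (illustrative), assuming a toy
# ViT-like model with a single extra (class) token: a checkpoint trained on
# a 14x14 grid is adapted to a 16x16 grid; only the patch-position rows are
# bicubic-interpolated, the class row is carried over unchanged.
if __name__ == "__main__":
    from types import SimpleNamespace
    toy_model = SimpleNamespace(
        patch_embed=SimpleNamespace(num_patches=16 * 16),
        pos_embed=torch.zeros(1, 16 * 16 + 1, 768),
    )
    checkpoint = {'pos_embed': torch.randn(1, 14 * 14 + 1, 768)}
    interpolate_pos_embed(toy_model, checkpoint)
    print(checkpoint['pos_embed'].shape)  # expected: torch.Size([1, 257, 768])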
================================================
FILE: inf_clip/models/timm_model.py
================================================
""" timm model adapter
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
try:
import timm
from timm.models.layers import Mlp, to_2tuple
try:
# old timm imports < 0.8.1
from timm.models.layers.attention_pool2d import RotAttentionPool2d
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
except ImportError:
# new timm imports >= 0.8.1
from timm.layers import RotAttentionPool2d
from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
timm = None
from ..utils import freeze_batch_norm_2d
class TimmModel(nn.Module):
""" timm model adapter
"""
def __init__(
self,
model_name,
embed_dim,
image_size=224,
pool='avg',
proj='linear',
proj_bias=False,
drop=0.,
drop_path=None,
patch_drop=None,
pretrained=False,
):
super().__init__()
if timm is None:
raise RuntimeError("Please `pip install timm` to use timm models.")
self.image_size = to_2tuple(image_size)
# setup kwargs that may not be common across all models
timm_kwargs = {}
if drop_path is not None:
timm_kwargs['drop_path_rate'] = drop_path
if patch_drop is not None:
timm_kwargs['patch_drop_rate'] = patch_drop
custom_pool = pool in ('abs_attn', 'rot_attn')
if proj:
assert proj in ("linear", "mlp", "none")
extra_proj = proj in ("linear", "mlp")
if not extra_proj and not custom_pool:
# use network classifier head as projection if no proj specified and no custom pooling used
# if projection is explicitly set to "none", the trunk output is passed through unchanged
proj_dim = 0 if proj == 'none' else embed_dim
self.trunk = timm.create_model(
model_name,
num_classes=proj_dim,
global_pool=pool,
pretrained=pretrained,
**timm_kwargs,
)
prev_chs = embed_dim
else:
self.trunk = timm.create_model(
model_name,
pretrained=pretrained,
**timm_kwargs,
)
feat_size = self.trunk.default_cfg.get('pool_size', None)
feature_ndim = 1 if not feat_size else 2
if custom_pool:
assert feature_ndim == 2
# if attn pooling used, remove both classifier and default pool
self.trunk.reset_classifier(0, global_pool='')
else:
# reset global pool if pool config set, otherwise leave as network default
reset_kwargs = dict(global_pool=pool) if pool else {}
self.trunk.reset_classifier(0, **reset_kwargs)
prev_chs = self.trunk.num_features
head_layers = OrderedDict()
# Add custom pooling to head
if pool == 'abs_attn':
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
prev_chs = embed_dim
elif pool == 'rot_attn':
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
prev_chs = embed_dim
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
if proj == 'linear':
head_layers['drop'] = nn.Dropout(drop)
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
elif proj == 'mlp':
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))
self.head = nn.Sequential(head_layers)
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
""" lock modules
Args:
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
"""
if not unlocked_groups:
# lock full model
for param in self.trunk.parameters():
param.requires_grad = False
if freeze_bn_stats:
freeze_batch_norm_2d(self.trunk)
else:
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
try:
# FIXME import here until API stable and in an official release
from timm.models.helpers import group_parameters, group_modules
except ImportError:
raise RuntimeError(
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
matcher = self.trunk.group_matcher()
gparams = group_parameters(self.trunk, matcher)
max_layer_id = max(gparams.keys())
max_layer_id = max_layer_id - unlocked_groups
for group_idx in range(max_layer_id + 1):
group = gparams[group_idx]
for param in group:
self.trunk.get_parameter(param).requires_grad = False
if freeze_bn_stats:
gmodules = group_modules(self.trunk, matcher, reverse=True)
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
freeze_batch_norm_2d(self.trunk, gmodules)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
try:
self.trunk.set_grad_checkpointing(enable)
except Exception:
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
def forward_trunk(self, x):
return self.trunk(x)
def forward_head(self, x):
return self.head(x)
def forward(self, x):
x = self.forward_trunk(x)
x = self.forward_head(x)
return x
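# Minimal usage sketch (illustrative), assuming timm is installed; 'resnet50'
# is a standard timm backbone and the default pretrained=False avoids any
# download. The trunk's classifier is reset and a linear head projects to
# embed_dim.
if __name__ == "__main__":
    tower = TimmModel('resnet50', embed_dim=512, image_size=224, pool='avg', proj='linear')
    out = tower(torch.randn(2, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([2, 512])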
================================================
FILE: inf_clip/models/tokenizer.py
================================================
""" CLIP tokenizer
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
import random
import string
from functools import lru_cache, partial
from typing import Callable, List, Optional, Union
import warnings
import ftfy
import numpy as np
import regex as re
import torch
# https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"
_nltk_init = False
DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
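def _demo_bytes_to_unicode_roundtrip():
    # Sketch of the reversibility property described above: all 256 byte
    # values map to distinct printable unicode characters, and the inverse
    # table recovers the original bytes exactly.
    enc = bytes_to_unicode()
    dec = {v: k for k, v in enc.items()}
    assert len(enc) == 256
    assert all(dec[enc[b]] == b for b in range(256))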
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = " ".join(text.split())
text = text.strip()
return text
def _clean_canonicalize(x):
# basic, remove whitespace, remove punctuation, lower case
return canonicalize_text(basic_clean(x))
def _clean_lower(x):
# basic, remove whitespace, lower case
return whitespace_clean(basic_clean(x)).lower()
def _clean_whitespace(x):
# basic, remove whitespace
return whitespace_clean(basic_clean(x))
def get_clean_fn(type: str):
if type == 'canonicalize':
return _clean_canonicalize
elif type == 'lower':
return _clean_lower
elif type == 'whitespace':
return _clean_whitespace
else:
assert False, f"Invalid clean function ({type})."
def canonicalize_text(
text,
*,
keep_punctuation_exact_string=None,
trans_punctuation: dict = str.maketrans("", "", string.punctuation),
):
"""Returns canonicalized `text` (lowercase and punctuation removed).
From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
Args:
text: string to be canonicalized.
keep_punctuation_exact_string: If provided, then this exact string kept.
For example providing '{}' will keep any occurrences of '{}' (but will
still remove '{' and '}' that appear separately).
"""
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
part.translate(trans_punctuation)
for part in text.split(keep_punctuation_exact_string)
)
else:
text = text.translate(trans_punctuation)
text = text.lower()
text = " ".join(text.split())
return text.strip()
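def _demo_canonicalize_text():
    # Worked example of the default cleaning behaviour: punctuation removed,
    # underscores treated as spaces, lowercased, whitespace collapsed, while
    # keep_punctuation_exact_string preserves a literal substring like '{}'.
    assert canonicalize_text("Hello,   World_!") == "hello world"
    assert canonicalize_text("a {} b,", keep_punctuation_exact_string="{}") == "a {} b"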
class SimpleTokenizer(object):
def __init__(
self,
bpe_path: str = default_bpe(),
additional_special_tokens: Optional[List[str]] = None,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'lower',
reduction_mask: str = ''
):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
special_tokens = ['<start_of_text>', '<end_of_text>']
if additional_special_tokens:
special_tokens += additional_special_tokens
vocab.extend(special_tokens)
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {t:t for t in special_tokens}
special = "|".join(special_tokens)
self.pat = re.compile(
special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE,
)
self.vocab_size = len(self.encoder)
self.all_special_ids = [self.encoder[t] for t in special_tokens]
self.sot_token_id = self.all_special_ids[0]
self.eot_token_id = self.all_special_ids[1]
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except Exception:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = self.clean_fn(text)
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:
""" Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length'
if self.reduction_fn is not None:
# use reduction strategy for tokenize if set, otherwise default to truncation below
return self.reduction_fn(
texts,
context_length=context_length,
sot_token_id=self.sot_token_id,
eot_token_id=self.eot_token_id,
encode_fn=self.encode,
)
all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = self.eot_token_id
result[i, :len(tokens)] = torch.tensor(tokens)
return result
_tokenizer = SimpleTokenizer()
def decode(output_ids: torch.Tensor):
output_ids = output_ids.cpu().numpy()
return _tokenizer.decode(output_ids)
def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:
return _tokenizer(texts, context_length=context_length)
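def _demo_tokenize():
    # Minimal sketch of the default tokenizer: each row starts with the
    # start-of-text id, ends with end-of-text, and is zero-padded out to the
    # context length (77 by default).
    out = tokenize(["a photo of a cat", "a photo of a dog"])
    assert out.shape == (2, 77)
    assert (out[:, 0] == _tokenizer.sot_token_id).all()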
def random_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
shuffle: bool = False,
):
all_tokens = [encode_fn(text) for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
tokens = torch.tensor(tokens)
num_tokens = len(tokens)
if num_tokens > context_length - 2: # 2 for sot and eot token
num_keep = context_length - 2
indices = torch.randperm(len(tokens))
indices = indices[:num_keep]
if not shuffle:
indices = indices.msort()
tokens = tokens[indices]
num_tokens = num_keep
result[i, 0] = sot_token_id
result[i, 1:num_tokens + 1] = tokens
result[i, num_tokens + 1] = eot_token_id
return result
def simple_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
):
all_tokens = [encode_fn(text) for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
num_tokens = len(tokens)
if num_tokens > context_length - 2: # 2 for sot and eot token
num_keep = context_length - 2
start_index = random.randint(0, num_tokens - num_keep)  # randint's upper bound is inclusive
tokens = tokens[start_index: start_index + num_keep]
tokens = [sot_token_id] + tokens + [eot_token_id]
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def syntax_mask_tokenize(
texts: Union[str, List[str]],
context_length: int,
sot_token_id: int,
eot_token_id: int,
encode_fn: Callable,
) -> torch.LongTensor:
""" Returns the tokenized representation of given input string(s).
Apply syntax masking before tokenize.
"""
import nltk
global _nltk_init
if not _nltk_init:
# download the required NLTK data the first time this is called
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
_nltk_init = True
def get_order(x):
if x.startswith('NN'):
return 1
elif x.startswith('JJ'):
return 2
elif x.startswith('VB'):
return 3
else:
return 4
# syntax masking
new_texts = []
for text in texts:
list_tokens = nltk.tokenize.word_tokenize(text)
pos_tags = nltk.pos_tag(list_tokens)
# sample the words by get_order method
order_list = [get_order(tag) for _, tag in pos_tags]
sorted_ids = np.argsort(np.array(order_list))
sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens
sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens
new_text = ''
for token in sampled_tokens:
new_text = new_text + str(token) + ' '
new_text = new_text.strip()
new_texts.append(new_text)
texts = new_texts
all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
for i, tokens in enumerate(all_tokens):
# still need to truncate first because some words produce two or more tokens
if len(tokens) > context_length:
tokens = tokens[:context_length] # Truncate
tokens[-1] = eot_token_id
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def get_reduction_mask_fn(type: str):
""" Choose strategy for dropping (masking) tokens to achieve target context length"""
assert type in ('simple', 'random', 'shuffle', 'syntax')
if type == 'simple':
return simple_mask_tokenize # randomly select block [start:end]
elif type == 'random':
return random_mask_tokenize # randomly drop tokens (keep order)
elif type == 'shuffle':
return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order)
elif type == 'syntax':
return syntax_mask_tokenize # randomly drop prioritized by syntax
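def _demo_reduction_mask():
    # Sketch of the reduction strategies, using 'simple' as the example:
    # over-long inputs keep a random contiguous block of tokens rather than
    # always truncating from the front.
    fn = get_reduction_mask_fn('simple')
    out = fn(
        ["one two three"],
        context_length=8,
        sot_token_id=_tokenizer.sot_token_id,
        eot_token_id=_tokenizer.eot_token_id,
        encode_fn=_tokenizer.encode,
    )
    assert out.shape == (1, 8)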
class HFTokenizer:
"""HuggingFace tokenizer wrapper"""
def __init__(
self,
tokenizer_name: str,
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,
clean: str = 'whitespace',
strip_sep_token: bool = False,
language: Optional[str] = None,
**kwargs
):
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)
set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)
if callable(set_lang_fn):
self.set_lang_fn = set_lang_fn
if language is not None:
self.set_language(language)
self.context_length = context_length
self.clean_fn = get_clean_fn(clean)
self.strip_sep_token = strip_sep_token
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length in class init or call.'
texts = [self.clean_fn(text) for text in texts]
input_ids = self.tokenizer.batch_encode_plus(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
).input_ids
if self.strip_sep_token:
input_ids = torch.where(
input_ids == self.tokenizer.sep_token_id,
torch.zeros_like(input_ids),
input_ids,
)
return input_ids
def set_language(self, src_lang):
if hasattr(self, 'set_lang_fn'):
self.set_lang_fn(src_lang)
else:
warnings.warn('Cannot set language for the tokenizer.')
class SigLipTokenizer:
"""HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs
"""
VOCAB_FILES = {
# english, vocab_size=32_000
"c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model",
# used in multilingual models (mT5, PaLI), vocab_size=250_000
"mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model",
}
def __init__(
self,
tokenizer_name: str,
context_length: Optional[int] = 64,
):
from transformers import T5TokenizerFast
if tokenizer_name in self.VOCAB_FILES:
# FIXME temporary hack?
import tempfile
import fsspec
vocab_file = self.VOCAB_FILES[tokenizer_name]
with tempfile.NamedTemporaryFile('wb') as dst:
with fsspec.open(vocab_file, 'rb') as src:
dst.write(src.read())
self.tokenizer = T5TokenizerFast(dst.name, legacy=False)
else:
self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False)
self.tokenizer.pad_token_id = 1
self.tokenizer.eos_token_id = 1
self.context_length = context_length
def save_pretrained(self, dest):
self.tokenizer.save_pretrained(dest)
def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:
# same cleaning as for default tokenizer, except lowercasing
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
if isinstance(texts, str):
texts = [texts]
context_length = context_length or self.context_length
assert context_length, 'Please set a valid context length in class init or call.'
texts = [canonicalize_text(basic_clean(text)) for text in texts]
output = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
truncation=True,
)
return output.input_ids
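# Minimal usage sketch (illustrative), assuming transformers is installed and
# network access so the 'c4-en' sentencepiece model can be fetched via fsspec
# on first use.
if __name__ == "__main__":
    siglip_tok = SigLipTokenizer('c4-en', context_length=64)
    ids = siglip_tok(["a photo of a cat"])
    print(ids.shape)  # expected: torch.Size([1, 64])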
================================================
FILE: inf_clip/models/transform.py
================================================
import numbers
import random
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop, ColorJitter, Grayscale
from ..constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from ..utils import to_2tuple
@dataclass
class PreprocessCfg:
size: Union[int, Tuple[int, int]] = 224
mode: str = 'RGB'
mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
std: Tuple[float, ...] = OPENAI_DATASET_STD
interpolation: str = 'bicubic'
resize_mode: str = 'shortest'
fill_color: int = 0
def __post_init__(self):
assert self.mode in ('RGB',)
@property
def num_channels(self):
return 3
@property
def input_size(self):
return (self.num_channels,) + to_2tuple(self.size)
_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
def merge_preprocess_dict(
base: Union[PreprocessCfg, Dict],
overlay: Dict,
):
""" Merge overlay key-value pairs on top of base preprocess cfg or dict.
Input dicts are filtered based on PreprocessCfg fields.
"""
if isinstance(base, PreprocessCfg):
base_clean = asdict(base)
else:
base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
if overlay:
overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}
base_clean.update(overlay_clean)
return base_clean
def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):
return merge_preprocess_dict(base, kwargs)
@dataclass
class AugmentationCfg:
scale: Tuple[float, float] = (0.9, 1.0)
ratio: Optional[Tuple[float, float]] = None
color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None
re_prob: Optional[float] = None
re_count: Optional[int] = None
use_timm: bool = False
# params for simclr_jitter_gray
color_jitter_prob: Optional[float] = None
gray_scale_prob: Optional[float] = None
def _setup_size(size, error_msg):
if isinstance(size, numbers.Number):
return int(size), int(size)
if isinstance(size, Sequence) and len(size) == 1:
return size[0], size[0]
if len(size) != 2:
raise ValueError(error_msg)
return size
class ResizeKeepRatio:
""" Resize and Keep Ratio
Copy & paste from `timm`
"""
def __init__(
self,
size,
longest=0.,
interpolation=InterpolationMode.BICUBIC,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
self.size = (size, size)
self.interpolation = interpolation
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
self.random_scale_prob = random_scale_prob
self.random_scale_range = random_scale_range
self.random_aspect_prob = random_aspect_prob
self.random_aspect_range = random_aspect_range
@staticmethod
def get_params(
img,
target_size,
longest,
random_scale_prob=0.,
random_scale_range=(0.85, 1.05),
random_aspect_prob=0.,
random_aspect_range=(0.9, 1.11)
):
"""Get parameters
"""
source_size = img.size[::-1] # h, w
h, w = source_size
target_h, target_w = target_size
ratio_h = h / target_h
ratio_w = w / target_w
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
if random_scale_prob > 0 and random.random() < random_scale_prob:
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
ratio_factor = (ratio_factor, ratio_factor)
else:
ratio_factor = (1., 1.)
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
return size
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be cropped and resized.
Returns:
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
"""
size = self.get_params(
img, self.size, self.longest,
self.random_scale_prob, self.random_scale_range,
self.random_aspect_prob, self.random_aspect_range
)
img = F.resize(img, size, self.interpolation)
return img
def __repr__(self):
format_string = self.__class__.__name__ + f'(size={self.size}'
format_string += f', interpolation={self.interpolation}'
format_string += f', longest={self.longest:.3f})'
return format_string
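def _demo_resize_keep_ratio():
    # Sketch of the aspect-preserving resize, assuming the default longest=0
    # (match the shortest edge): a 400x200 image resized toward (224, 224)
    # becomes 448x224, preserving the 2:1 aspect ratio.
    from PIL import Image
    resize = ResizeKeepRatio(224)
    out = resize(Image.new('RGB', (400, 200)))
    assert out.size == (448, 224)  # PIL size is (width, height)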
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
"""Center crops and/or pads the given image.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
img (PIL Image or Tensor): Image to be cropped.
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
it is used for both directions.
fill (int, Tuple[int]): Padding color
Returns:
PIL Image or Tensor: Cropped image.
"""
if isinstance(output_size, numbers.Number):
output_size = (int(output_size), int(output_size))
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])
_, image_height, image_width = F.get_dimensions(img)
crop_height, crop_width = output_size
if crop_width > image_width or crop_height > image_height:
padding_ltrb = [
(crop_width - image_width) // 2 if crop_width > image_width else 0,
(crop_height - image_height) // 2 if crop_height > image_height else 0,
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = F.pad(img, padding_ltrb, fill=fill)
_, image_height, image_width = F.get_dimensions(img)
if crop_width == image_width and crop_height == image_height:
return img
crop_top = int(round((image_height - crop_height) / 2.0))
crop_left = int(round((image_width - crop_width) / 2.0))
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
class CenterCropOrPad(torch.nn.Module):
"""Crops the given image at the center.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
"""
def __init__(self, size, fill=0):
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.fill = fill
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped.
Returns:
PIL Image or Tensor: Cropped image.
"""
return center_crop_or_pad(img, self.size, fill=self.fill)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(size={self.size})"
def _convert_to_rgb(image):
return image.convert('RGB')
class color_jitter(object):
"""
Apply Color Jitter to the PIL image with a specified probability.
"""
def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):
assert 0. <= p <= 1.
self.p = p
self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
class gray_scale(object):
"""
Apply Gray Scale to the PIL image with a specified probability.
"""
def __init__(self, p=0.2):
assert 0. <= p <= 1.
self.p = p
self.transf = Grayscale(num_output_channels=3)
def __call__(self, img):
if random.random() < self.p:
return self.transf(img)
else:
return img
def image_transform(
image_size: Union[int, Tuple[int, int]],
is_train: bool,
mean: Optional[Tuple[float, ...]] = None,
std: Optional[Tuple[float, ...]] = None,
resize_mode: Optional[str] = None,
interpolation: Optional[str] = None,
fill_color: int = 0,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
mean = mean or OPENAI_DATASET_MEAN
if not isinstance(mean, (list, tuple)):
mean = (mean,) * 3
std = std or OPENAI_DATASET_STD
if not isinstance(std, (list, tuple)):
std = (std,) * 3
interpolation = interpolation or 'bicubic'
assert interpolation in ['bicubic', 'bilinear', 'random']
# NOTE 'random' is ignored when building interpolation_mode, so inference falls back to BICUBIC if it was set
interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
resize_mode = resize_mode or 'shortest'
assert resize_mode in ('shortest', 'longest', 'squash')
if isinstance(aug_cfg, dict):
aug_cfg = AugmentationCfg(**aug_cfg)
else:
aug_cfg = aug_cfg or AugmentationCfg()
normalize = Normalize(mean=mean, std=std)
if is_train:
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
use_timm = aug_cfg_dict.pop('use_timm', False)
if use_timm:
from timm.data import create_transform # timm can still be optional
if isinstance(image_size, (tuple, list)):
assert len(image_size) >= 2
input_size = (3,) + image_size[-2:]
else:
input_size = (3, image_size, image_size)
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
# drop extra non-timm items
aug_cfg_dict.pop('color_jitter_prob', None)
aug_cfg_dict.pop('gray_scale_prob', None)
train_transform = create_transform(
input_size=input_size,
is_training=True,
hflip=0.,
mean=mean,
std=std,
re_mode='pixel',
interpolation=interpolation,
**aug_cfg_dict,
)
else:
train_transform = [
RandomResizedCrop(
image_size,
scale=aug_cfg_dict.pop('scale'),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
]
if aug_cfg.color_jitter_prob:
assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
train_transform.extend([
color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)
])
if aug_cfg.gray_scale_prob:
train_transform.extend([
gray_scale(aug_cfg.gray_scale_prob)
])
train_transform.extend([
ToTensor(),
normalize,
])
train_transform = Compose(train_transform)
if aug_cfg_dict:
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
return train_transform
else:
if resize_mode == 'longest':
transforms = [
ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
CenterCropOrPad(image_size, fill=fill_color)
]
elif resize_mode == 'squash':
if isinstance(image_size, int):
image_size = (image_size, image_size)
transforms = [
Resize(image_size, interpolation=interpolation_mode),
]
else:
assert resize_mode == 'shortest'
if not isinstance(image_size, (tuple, list)):
image_size = (image_size, image_size)
if image_size[0] == image_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
transforms = [
Resize(image_size[0], interpolation=interpolation_mode)
]
else:
# resize shortest edge to matching target dim for non-square target
transforms = [ResizeKeepRatio(image_size)]
transforms += [CenterCrop(image_size)]
transforms.extend([
_convert_to_rgb,
ToTensor(),
normalize,
])
return Compose(transforms)
def image_transform_v2(
cfg: PreprocessCfg,
is_train: bool,
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
return image_transform(
image_size=cfg.size,
is_train=is_train,
mean=cfg.mean,
std=cfg.std,
interpolation=cfg.interpolation,
resize_mode=cfg.resize_mode,
fill_color=cfg.fill_color,
aug_cfg=aug_cfg,
)
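# Minimal usage sketch (illustrative): the default PreprocessCfg (224px,
# bicubic, shortest-edge resize) yields the standard eval pipeline of
# Resize -> CenterCrop -> ToTensor -> Normalize.
if __name__ == "__main__":
    from PIL import Image
    preprocess = image_transform_v2(PreprocessCfg(), is_train=False)
    tensor = preprocess(Image.new('RGB', (640, 480)))
    print(tensor.shape)  # expected: torch.Size([3, 224, 224])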
================================================
FILE: inf_clip/models/transformer.py
================================================
from collections import OrderedDict
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from .pos_embed import get_2d_sincos_pos_embed
from ..utils import to_2tuple
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x.to(orig_type)
class QuickGELU(nn.Module):
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
def __init__(self, prob, exclude_first_token=True):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.exclude_first_token = exclude_first_token # exclude CLS token
def forward(self, x):
if not self.training or self.prob == 0.:
return x
if self.exclude_first_token:
cls_tokens, x = x[:, :1], x[:, 1:]
else:
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
batch = x.size()[0]
num_tokens = x.size()[1]
batch_indices = torch.arange(batch)
batch_indices = batch_indices[..., None]
keep_prob = 1 - self.prob
num_patches_keep = max(1, int(num_tokens * keep_prob))
rand = torch.randn(batch, num_tokens)
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
x = x[batch_indices, patch_indices_keep]
if self.exclude_first_token:
x = torch.cat((cls_tokens, x), dim=1)
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
scaled_cosine: bool = False,
scale_heads: bool = False,
logit_scale_max: float = math.log(1. / 0.01),
batch_first: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.logit_scale_max = logit_scale_max
self.batch_first = batch_first
self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
if qkv_bias:
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
else:
self.in_proj_bias = None
if self.scaled_cosine:
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
else:
self.logit_scale = None
self.attn_drop = nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = nn.Linear(dim, dim)
self.out_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
if self.batch_first:
x = x.transpose(0, 1)
L, N, C = x.shape
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)
k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)
v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)
if attn_mask is not None and attn_mask.dtype == torch.bool:
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
attn_mask = new_attn_mask
if self.logit_scale is not None:
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
attn = attn.view(N, self.num_heads, L, L) * logit_scale
attn = attn.view(-1, L, L)
if attn_mask is not None:
attn = attn + attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
else:
if self.use_fsdpa:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = torch.bmm(q, k.transpose(-1, -2))
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = torch.bmm(attn, v)
if self.head_scale is not None:
x = x.view(N, self.num_heads, L, C) * self.head_scale
x = x.view(-1, L, C)
x = x.transpose(0, 1).reshape(L, N, C)
if self.batch_first:
x = x.transpose(0, 1)
x = self.out_proj(x)
x = self.out_drop(x)
return x
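def _demo_attention_shapes():
    # Shape sketch for the custom Attention block, assuming batch_first=True:
    # input and output keep the (N, L, C) layout whether the scaled-cosine
    # path or the plain dot-product path is taken.
    attn = Attention(dim=64, num_heads=4, scaled_cosine=True)
    x = torch.randn(2, 10, 64)
    assert attn(x).shape == (2, 10, 64)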
class AttentionalPooler(nn.Module):
def __init__(
self,
d_model: int,
context_dim: int,
n_head: int = 8,
n_queries: int = 256,
norm_layer: Callable = LayerNorm,
):
super().__init__()
self.query = nn.Parameter(torch.randn(n_queries, d_model))
self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)
self.ln_q = norm_layer(d_model)
self.ln_k = norm_layer(context_dim)
def forward(self, x: torch.Tensor):
N = x.shape[0]
x = self.ln_k(x)
q = self.ln_q(self.query)
out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]
return out
class ResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
is_cross_attention: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(
q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask
)[0]
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
class CustomResidualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
scale_cosine_attn: bool = False,
scale_heads: bool = False,
scale_attn: bool = False,
scale_fc: bool = False,
batch_first: bool = True,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.attn = Attention(
d_model,
n_head,
scaled_cosine=scale_cosine_attn,
scale_heads=scale_heads,
batch_first=batch_first,
)
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
def get_reference_weight(self):
return self.mlp.c_fc.weight
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
x = x + self.ls_2(self.mlp(self.ln_2(x)))
return x
def _expand_token(token, batch_size: int):
return token.view(1, 1, -1).expand(batch_size, -1, -1)
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
):
super().__init__()
self.width = width
self.layers = layers
self.batch_first = batch_first
self.grad_checkpointing = False
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
return self.resblocks[0].mlp.c_fc.int8_original_dtype
return self.resblocks[0].mlp.c_fc.weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1).contiguous() # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x
class CustomTransformer(nn.Module):
""" A custom transformer that can use different block types. """
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
batch_first: bool = True,
block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',
):
super().__init__()
self.width = width
self.layers = layers
self.batch_first = batch_first  # run transformer stack in batch-first (N, L, D) layout
self.grad_checkpointing = False
if isinstance(block_types, str):
block_types = [block_types] * layers
assert len(block_types) == layers
def _create_block(bt: str):
if bt == 'CustomResidualAttentionBlock':
return CustomResidualAttentionBlock(
width,
heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
else:
assert False
self.resblocks = nn.ModuleList([
_create_block(bt)
for bt in block_types
])
def get_cast_dtype(self) -> torch.dtype:
weight = self.resblocks[0].get_reference_weight()
if hasattr(weight, 'int8_original_dtype'):
return weight.int8_original_dtype
return weight.dtype
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
if not self.batch_first:
x = x.transpose(0, 1) # NLD -> LND
for r in self.resblocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
x = checkpoint(r, x, None, None, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
if not self.batch_first:
x = x.transpose(0, 1) # LND -> NLD
return x
class VisionTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
ls_init_value: float = None,
attentional_pool: bool = False,
attn_pooler_queries: int = 256,
attn_pooler_heads: int = 8,
output_dim: int = 512,
patch_dropout: float = 0.,
no_ln_pre: bool = False,
pos_embed_type: str = 'learnable',
pool_type: str = 'tok',
final_ln_after_pool: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('tok', 'avg', 'none')
self.output_tokens = output_tokens
image_height, image_width = self.image_size = to_2tuple(image_size)
patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
if pos_embed_type == 'learnable':
self.positional_embedding = nn.Parameter(
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
elif pos_embed_type == 'sin_cos_2d':
# fixed sin-cos embedding
assert self.grid_size[0] == self.grid_size[1],\
'currently sin cos 2d pos embedding only supports square input'
self.positional_embedding = nn.Parameter(
torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)
pos_embed = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)
self.positional_embedding.data.copy_(torch.from_numpy(pos_embed).float())
else:
raise ValueError
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)
self.transformer = Transformer(
width,
layers,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
if attentional_pool:
if isinstance(attentional_pool, str):
self.attn_pool_type = attentional_pool
self.pool_type = 'none'
if attentional_pool in ('parallel', 'cascade'):
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=1,
)
else:
assert False
else:
self.attn_pool_type = ''
self.pool_type = pool_type
self.attn_pool = AttentionalPooler(
output_dim,
width,
n_head=attn_pooler_heads,
n_queries=attn_pooler_queries,
)
self.attn_pool_contrastive = None
pool_dim = output_dim
else:
self.attn_pool = None
pool_dim = width
self.pool_type = pool_type
self.ln_post = norm_layer(pool_dim)
self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))
self.init_parameters()
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
for param in self.parameters():
param.requires_grad = False
if unlocked_groups != 0:
groups = [
[
self.conv1,
self.class_embedding,
self.positional_embedding,
self.ln_pre,
],
*self.transformer.resblocks[:-1],
[
self.transformer.resblocks[-1],
self.ln_post,
],
self.proj,
]
def _unlock(x):
if isinstance(x, Sequence):
for g in x:
_unlock(g)
else:
if isinstance(x, torch.nn.Parameter):
x.requires_grad = True
else:
for p in x.parameters():
p.requires_grad = True
_unlock(groups[-unlocked_groups:])
def init_parameters(self):
# FIXME OpenAI CLIP did not define an init for the VisualTransformer
# TODO experiment if default PyTorch init, below, or alternate init is best.
# nn.init.normal_(self.class_embedding, std=self.scale)
# nn.init.normal_(self.positional_embedding, std=self.scale)
#
# proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
# attn_std = self.transformer.width ** -0.5
# fc_std = (2 * self.transformer.width) ** -0.5
# for block in self.transformer.resblocks:
# nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
# nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
# nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
# nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
#
# if self.text_projection is not None:
# nn.init.normal_(self.text_projection, std=self.scale)
pass
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pool_type == 'avg':
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
elif self.pool_type == 'tok':
pooled, tokens = x[:, 0], x[:, 1:]
else:
pooled = tokens = x
return pooled, tokens
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
# shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.patch_dropout(x)
x = self.ln_pre(x)
x = self.transformer(x)
if self.attn_pool is not None:
if self.attn_pool_contrastive is not None:
# This is untested, WIP pooling that should match paper
x = self.ln_post(x) # TBD LN first or separate one after each pool?
tokens = self.attn_pool(x)
if self.attn_pool_type == 'parallel':
pooled = self.attn_pool_contrastive(x)
else:
assert self.attn_pool_type == 'cascade'
pooled = self.attn_pool_contrastive(tokens)
else:
# this is the original OpenCLIP CoCa setup, does not match paper
x = self.attn_pool(x)
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
elif self.final_ln_after_pool:
pooled, tokens = self._global_pool(x)
pooled = self.ln_post(pooled)
else:
x = self.ln_post(x)
pooled, tokens = self._global_pool(x)
if self.proj is not None:
pooled = pooled @ self.proj
if self.output_tokens:
return pooled, tokens
return pooled
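def _demo_vision_transformer():
    # Shape sketch, assuming a deliberately tiny ViT: 32px input with 16px
    # patches gives a 2x2 grid (4 patch tokens + 1 class token).
    model = VisionTransformer(
        image_size=32, patch_size=16, width=64, layers=2, heads=4,
        mlp_ratio=4.0, output_dim=128,
    )
    assert model(torch.randn(2, 3, 32, 32)).shape == (2, 128)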
def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
if pool_type == 'first':
pooled, tokens = x[:, 0], x[:, 1:]
elif pool_type == 'last':
pooled, tokens = x[:, -1], x[:, :-1]
elif pool_type == 'argmax':
# take features from the eot embedding (eot_token is the highest number in each sequence)
assert text is not None
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
else:
pooled = tokens = x
return pooled, tokens
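def _demo_text_global_pool():
    # Sketch of the 'argmax' rule: the pooled feature is taken at the
    # position of the largest token id per row, which for CLIP-style inputs
    # is the end-of-text token.
    x = torch.randn(2, 5, 8)
    text = torch.tensor([[1, 3, 9, 0, 0], [1, 3, 4, 9, 0]])
    pooled, tokens = text_global_pool(x, text, pool_type='argmax')
    assert pooled.shape == (2, 8)
    assert torch.equal(pooled[0], x[0, 2]) and torch.equal(pooled[1], x[1, 3])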
class TextTransformer(nn.Module):
output_tokens: torch.jit.Final[bool]
def __init__(
self,
context_length: int = 77,
vocab_size: int = 49408,
width: int = 512,
heads: int = 8,
layers: int = 12,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
output_dim: int = 512,
embed_cls: bool = False,
no_causal_mask: bool = False,
pad_id: int = 0,
pool_type: str = 'argmax',
proj_bias: bool = False,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_tokens: bool = False,
):
super().__init__()
assert pool_type in ('first', 'last', 'argmax', 'none')
self.output_tokens = output_tokens
self.num_pos = self.context_length = context_length
self.vocab_size = vocab_size
self.width = width
self.output_dim = output_dim
self.heads = heads
self.pad_id = pad_id
self.pool_type = pool_type
self.token_embedding = nn.Embedding(vocab_size, width)
if embed_cls:
self.cls_emb = nn.Parameter(torch.empty(width))
self.num_pos += 1
else:
self.cls_emb = None
self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
self.transformer = Transformer(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
)
self.ln_final = norm_layer(width)
if no_causal_mask:
self.attn_mask = None
else:
self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)
if proj_bias:
self.text_projection = nn.Linear(width, output_dim)
else:
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
self.init_parameters()
def init_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if self.cls_emb is not None:
nn.init.normal_(self.cls_emb, std=0.01)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)
if self.text_projection.bias is not None:
nn.init.zeros_(self.text_projection.bias)
else:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def build_causal_mask(self):
# build the causal attention mask once at init (registered as a non-persistent buffer)
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float("-inf"))
mask.triu_(1)  # zero out the diagonal and below, keep -inf strictly above
return mask
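# For num_pos == 3 the resulting additive mask is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so position i can only attend to positions <= i.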
def build_cls_mask(self, text, cast_dtype: torch.dtype):
cls_mask = (text != self.pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
return additive_mask
def forward(self, text):
cast_dtype = self.transformer.get_cast_dtype()
seq_len = text.shape[1]
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
attn_mask = self.attn_mask
if self.cls_emb is not None:
seq_len += 1
x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)
cls_mask = self.build_cls_mask(text, cast_dtype)
if attn_mask is not None:
attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
x = x + self.positional_embedding[:seq_len].to(cast_dtype)
x = self.transformer(x, attn_mask=attn_mask)
# x.shape = [batch_size, n_ctx, transformer.width]
if self.cls_emb is not None:
# presence of appended cls embed (CoCa) overrides pool_type, always take last token
pooled, tokens = text_global_pool(x, pool_type='last')
pooled = self.ln_final(pooled) # final LN applied after pooling in this case
else:
x = self.ln_final(x)
pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)
if self.text_projection is not None:
if isinstance(self.text_projection, nn.Linear):
pooled = self.text_projection(pooled)
else:
pooled = pooled @ self.text_projection
if self.output_tokens:
return pooled, tokens
return pooled
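# Minimal usage sketch (illustrative; default hyper-parameters assumed):
#   tt = TextTransformer()                       # width=512, 12 layers, 'argmax' pooling
#   tokens = torch.randint(0, 49408, (4, 77))    # [batch, context_length] token ids
#   feats = tt(tokens)                           # [4, 512] pooled, projected text features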
class MultimodalTransformer(Transformer):
def __init__(
self,
width: int,
layers: int,
heads: int,
context_length: int = 77,
mlp_ratio: float = 4.0,
ls_init_value: float = None,
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm,
output_dim: int = 512,
batch_first: bool = True,
):
super().__init__(
width=width,
layers=layers,
heads=heads,
mlp_ratio=mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
batch_first=batch_first,
)
self.context_length = context_length
self.cross_attn = nn.ModuleList([
ResidualAttentionBlock(
width,
heads,
mlp_ratio,
ls_init_value=ls_init_value,
act_layer=act_layer,
norm_layer=norm_layer,
is_cross_attention=True,
batch_first=batch_first,
)
for _ in range(layers)
])
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
self.ln_final = norm_layer(width)
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
def init_parameters(self):
# NOTE: this class subclasses Transformer directly, so the blocks live on `self`;
# referencing `self.transformer` here (as upstream open_clip does) raises AttributeError
proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
attn_std = self.width ** -0.5
fc_std = (2 * self.width) ** -0.5
for block in self.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
for block in self.cross_attn:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.width ** -0.5)
def build_attention_mask(self):
# build the causal attention mask once at init (registered as a non-persistent buffer)
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # zero out the diagonal and below, keep -inf strictly above
return mask
def forward(self, image_embs, text_embs):
seq_len = text_embs.shape[1]
if not self.batch_first:
image_embs = image_embs.permute(1, 0, 2) # NLD -> LND
text_embs = text_embs.permute(1, 0, 2) # NLD -> LND
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
if self.grad_checkpointing and not torch.jit.is_scripting():
# TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
else:
text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
if not self.batch_first:
text_embs = text_embs.permute(1, 0, 2) # LND -> NLD
out = self.ln_final(text_embs)
if self.text_projection is not None:
out = out @ self.text_projection
return out
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
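# Shape sketch (batch_first=True): image_embs [B, N_img, width] and text_embs [B, N_txt, width]
# with N_txt <= context_length -> out [B, N_txt, output_dim] after ln_final and the projection.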
================================================
FILE: inf_clip/openai.py
================================================
""" OpenAI pretrained model functions
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import os
import warnings
from typing import List, Optional, Union
import torch
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
from .models.clip_arch import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
__all__ = ["list_openai_models", "load_openai_model"]
def list_openai_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list_pretrained_models_by_tag('openai')
def load_openai_model(
name: str,
precision: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None,
cache_dir: Optional[str] = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
precision : str
Model precision; if None, defaults to 'fp32' when device == 'cpu' and 'fp16' otherwise.
device : Union[str, torch.device]
The device to put the loaded model
cache_dir : Optional[str]
The directory to cache the downloaded model weights
Returns
-------
model : torch.nn.Module
The CLIP model (only the model is returned; preprocessing transforms are built elsewhere)
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if precision is None:
precision = 'fp32' if device == 'cpu' else 'fp16'
if get_pretrained_url(name, 'openai'):
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
state_dict = torch.load(model_path, map_location="cpu")
# Build a non-jit model from the OpenAI jitted model state dict
cast_dtype = get_cast_dtype(precision)
try:
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
except KeyError:
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
model = model.to(device)
# FIXME support pure fp16/bf16 precision modes
if precision != 'fp16':
model.float()
if precision == 'bf16':
# for bf16, convert back to low-precision
convert_weights_to_lp(model, dtype=torch.bfloat16)
# add mean / std attributes for consistency with OpenCLIP models
model.visual.image_mean = OPENAI_DATASET_MEAN
model.visual.image_std = OPENAI_DATASET_STD
return model
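# Usage sketch (illustrative; downloads weights on first call, name must be in the 'openai' tag):
#   model = load_openai_model('RN50', precision='fp32', device='cpu')
#   with torch.no_grad():
#       image_features = model.encode_image(torch.randn(1, 3, 224, 224))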
================================================
FILE: inf_clip/pretrained.py
================================================
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union
import torch
import numpy as np
from tqdm import tqdm
from .models.clip_arch import CLIP, CustomTextCLIP
from .models.transformer import TextTransformer, Transformer
from .constants import (
IMAGENET_MEAN,
IMAGENET_STD,
INCEPTION_MEAN,
INCEPTION_STD,
OPENAI_DATASET_MEAN,
OPENAI_DATASET_STD,
)
__version__ = "2.26.1"
try:
from huggingface_hub import hf_hub_download
hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_download = None
_has_hf_hub = False
def _pcfg(url='', hf_hub='', **kwargs):
# OpenAI / OpenCLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': OPENAI_DATASET_MEAN,
'std': OPENAI_DATASET_STD,
'interpolation': 'bicubic',
'resize_mode': 'shortest',
**kwargs,
}
def _slpcfg(url='', hf_hub='', **kwargs):
# SigLIP defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': INCEPTION_MEAN,
'std': INCEPTION_STD,
'interpolation': 'bicubic',
'resize_mode': 'squash',
**kwargs,
}
def _apcfg(url='', hf_hub='', **kwargs):
# CLIPA defaults
return {
'url': url,
'hf_hub': hf_hub,
'mean': IMAGENET_MEAN,
'std': IMAGENET_STD,
'interpolation': 'bilinear',
'resize_mode': 'squash',
**kwargs,
}
def _mccfg(url='', hf_hub='', **kwargs):
# MobileCLIP
return {
'url': url,
'hf_hub': hf_hub,
'mean': (0., 0., 0.),
'std': (1., 1., 1.),
'interpolation': 'bilinear',
'resize_mode': 'shortest',
**kwargs,
}
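# The four preset factories above return dicts of the same shape; they differ only in the
# normalization and resize defaults, e.g. _pcfg() == {'url': '', 'hf_hub': '',
# 'mean': OPENAI_DATASET_MEAN, 'std': OPENAI_DATASET_STD, 'interpolation': 'bicubic',
# 'resize_mode': 'shortest'}.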
_RN50 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN50_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
cc12m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)
_RN101 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN101_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
yfcc15m=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)
_RN50x4 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)
_RN50x16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)
_RN50x64 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)
_VITB32 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
laion2b_e16=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
# DataComp-M models
datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),
commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),
commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),
commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),
commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),
commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),
commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),
# DataComp-S models
datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),
commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),
commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),
commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),
commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
)
_VITB32_quickgelu = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"),
)
_VITB32_256 = dict(
datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),
)
_VITB16 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
# DataComp-L models
datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),
commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),
commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),
commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),
commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),
commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
# DFN
dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/')
)
_VITB16_quickgelu = dict(
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"),
)
_VITB16_PLUS_240 = dict(
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)
_VITL14 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
laion400m_e31=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
laion400m_e32=_pcfg(
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
laion2b_s32b_b82k=_pcfg(
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
mean=INCEPTION_MEAN, std=INCEPTION_STD),
# DataComp-XL models
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),
commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
)
_VITL14_quickgelu = dict(
metaclip_400m=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"),
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"),
dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'),
)
_VITL14_336 = dict(
openai=_pcfg(
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)
_VITH14 = dict(
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)
_VITH14_quickgelu = dict(
metaclip_fullcc=_pcfg(
"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"),
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
interpolation="bicubic",
resize_mode="squash"
),
)
_VITH14_378_quickgelu = dict(
dfn5b=_pcfg(
hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
interpolation="bicubic",
resize_mode="squash"
),
)
_VITg14 = dict(
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)
_VITbigG14 = dict(
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)
_robertaViTB32 = dict(
laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)
_xlmRobertaBaseViTB32 = dict(
laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)
_xlmRobertaLargeFrozenViTH14 = dict(
frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)
_convnext_base = dict(
laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)
_convnext_base_w = dict(
laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)
_convnext_base_w_320 = dict(
laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)
_convnext_large_d = dict(
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)
_convnext_large_d_320 = dict(
laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)
_convnext_xxlarge = dict(
laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)
_coca_VITB32 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)
_coca_VITL14 = dict(
laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)
_PRETRAINED = {
"RN50": _RN50,
"RN50-quickgelu": _RN50_quickgelu,
"RN101": _RN101,
"RN101-quickgelu": _RN101_quickgelu,
"RN50x4": _RN50x4,
"RN50x16": _RN50x16,
"RN50x64": _RN50x64,
"ViT-B-32": _VITB32,
"ViT-B-32-256": _VITB32_256,
"ViT-B-32-quickgelu": _VITB32_quickgelu,
"ViT-B-16": _VITB16,
"ViT-B-16-quickgelu": _VITB16_quickgelu,
"ViT-B-16-plus-240": _VITB16_PLUS_240,
"ViT-L-14": _VITL14,
"ViT-L-14-quickgelu": _VITL14_quickgelu,
"ViT-L-14-336": _VITL14_336,
"ViT-H-14": _VITH14,
"ViT-H-14-quickgelu": _VITH14_quickgelu,
"ViT-H-14-378-quickgelu": _VITH14_378_quickgelu,
"ViT-g-14": _VITg14,
"ViT-bigG-14": _VITbigG14,
"roberta-ViT-B-32": _robertaViTB32,
"xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
"xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
"convnext_base": _convnext_base,
"convnext_base_w": _convnext_base_w,
"convnext_base_w_320": _convnext_base_w_320,
"convnext_large_d": _convnext_large_d,
"convnext_large_d_320": _convnext_large_d_320,
"convnext_xxlarge": _convnext_xxlarge,
"coca_ViT-B-32": _coca_VITB32,
"coca_ViT-L-14": _coca_VITL14,
"EVA01-g-14": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt
laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),
),
"EVA01-g-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt
merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),
),
"EVA02-B-16": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt
merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),
),
"EVA02-L-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt
merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),
),
"EVA02-L-14-336": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt
merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),
),
"EVA02-E-14": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt
laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),
),
"EVA02-E-14-plus": dict(
# from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt
laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),
),
"ViT-B-16-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),
),
"ViT-B-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),
),
"ViT-B-16-SigLIP-i18n-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),
),
"ViT-B-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),
),
"ViT-B-16-SigLIP-512": dict(
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),
),
"ViT-L-16-SigLIP-256": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),
),
"ViT-L-16-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),
),
"ViT-SO400M-14-SigLIP": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),
),
"ViT-SO400M-14-SigLIP-384": dict(
webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),
),
"ViT-L-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),
),
"ViT-L-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),
),
"ViT-H-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),
),
"ViT-H-14-CLIPA-336": dict(
laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),
),
"ViT-bigG-14-CLIPA": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),
),
"ViT-bigG-14-CLIPA-336": dict(
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),
),
"nllb-clip-base": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),
),
"nllb-clip-large": dict(
v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),
),
"nllb-clip-base-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),
),
"nllb-clip-large-siglip": dict(
v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),
mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),
),
"MobileCLIP-S1": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),
"MobileCLIP-S2": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),
"MobileCLIP-B": dict(
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),
datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),
),
"ViTamin-S": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),
),
"ViTamin-S-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),
),
"ViTamin-B": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),
),
"ViTamin-B-LTT": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),
),
"ViTamin-L": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),
),
"ViTamin-L-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),
),
"ViTamin-L-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),
),
"ViTamin-L-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),
),
"ViTamin-L2": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),
),
"ViTamin-L2-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),
),
"ViTamin-L2-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),
),
"ViTamin-L2-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),
),
"ViTamin-XL-256": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),
),
"ViTamin-XL-336": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),
),
"ViTamin-XL-384": dict(
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),
),
}
def _clean_tag(tag: str):
# normalize pretrained tags
return tag.lower().replace('-', '_')
def list_pretrained(as_str: bool = False):
""" returns list of pretrained models
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
"""
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
def list_pretrained_models_by_tag(tag: str):
""" return all models having the specified pretrain tag """
models = []
tag = _clean_tag(tag)
for k in _PRETRAINED.keys():
if tag in _PRETRAINED[k]:
models.append(k)
return models
def list_pretrained_tags_by_model(model: str):
""" return all pretrain tags for the specified model architecture """
tags = []
if model in _PRETRAINED:
tags.extend(_PRETRAINED[model].keys())
return tags
def is_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return False
return _clean_tag(tag) in _PRETRAINED[model]
def get_pretrained_cfg(model: str, tag: str):
if model not in _PRETRAINED:
return {}
model_pretrained = _PRETRAINED[model]
return model_pretrained.get(_clean_tag(tag), {})
def get_pretrained_url(model: str, tag: str):
cfg = get_pretrained_cfg(model, _clean_tag(tag))
return cfg.get('url', '')
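# Registry usage sketch (illustrative):
#   list_pretrained(as_str=True)[:2]            # ['RN50:openai', 'RN50:yfcc15m']
#   list_pretrained_tags_by_model('ViT-L-14')   # ['openai', 'laion400m_e31', ...]
#   get_pretrained_url('RN50', 'openai')        # direct download URL for the OpenAI weights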
def download_pretrained_from_url(
url: str,
cache_dir: Union[str, None] = None,
):
if not cache_dir:
cache_dir = os.path.expanduser("~/.cache/clip")
os.makedirs(cache_dir, exist_ok=True)
filename = os.path.basename(url)
if 'openaipublic' in url:
expected_sha256 = url.split("/")[-2]
elif 'mlfoundations' in url:
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
else:
expected_sha256 = ''
download_target = os.path.join(cache_dir, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if expected_sha256:
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
else:
return download_target
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return _has_hf_hub
def download_pretrained_from_hf(
model_id: str,
filename: str = 'open_clip_pytorch_model.bin',
revision=None,
cache_dir: Union[str, None] = None,
):
has_hf_hub(True)
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
return cached_file
def download_pretrained(
cfg: Dict,
force_hf_hub: bool = False,
cache_dir: Union[str, None] = None,
):
target = ''
if not cfg:
return target
download_url = cfg.get('url', '')
download_hf_hub = cfg.get('hf_hub', '')
if download_hf_hub and force_hf_hub:
# use HF hub even if url exists
download_url = ''
if download_url:
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
elif download_hf_hub:
has_hf_hub(True)
# we assume the hf_hub entries in pretrained config combine model_id + filename in
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
model_id, filename = os.path.split(download_hf_hub)
if filename:
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
else:
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
return target
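# Resolution sketch for the two hf_hub forms found in _PRETRAINED:
#   download_pretrained({'hf_hub': 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'})
#     -> trailing slash: fetches the default 'open_clip_pytorch_model.bin' from that repo
#   download_pretrained({'hf_hub': 'jienengchen/ViTamin-S/pytorch_model.bin'})
#     -> explicit filename: fetches exactly that file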
@torch.no_grad()
def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):
""" Load weights from .npz checkpoints for official Google big_vision image-text models
Currently supports SigLIP source models, targeting a CustomTextCLIP destination model
with a timm image encoder.
"""
from timm.layers import resample_patch_embed, resample_abs_pos_embed
def _n2p(w, t=True):
if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.flatten()
if t:
if w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
interpolation = 'bilinear'
antialias = False
def _convert_timm_img(module, prefix):
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:
embed_conv_w = resample_patch_embed(
embed_conv_w,
module.patch_embed.proj.weight.shape[-2:],
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.patch_embed.proj.weight.copy_(embed_conv_w)
module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
if module.cls_token is not None:
module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
if pos_embed_w.shape != module.pos_embed.shape:
assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'
num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)
pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w,
new_size=module.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
interpolation=interpolation,
antialias=antialias,
verbose=True,
)
module.pos_embed.copy_(pos_embed_w)
mha_sub, b_sub, ln1_sub = (0, 0, 1)
for i, block in enumerate(module.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if module.attn_pool is not None:
block_prefix = f'{prefix}MAPHead_0/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
module.attn_pool.kv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
module.attn_pool.kv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
for r in range(2):
getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
def _convert_openclip_transformer(module: Transformer, prefix):
for i, block in enumerate(module.resblocks.children()):
block_prefix = f'{prefix}encoderblock_{i}/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
block.attn.in_proj_weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
block.attn.in_proj_bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))
block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))
block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))
block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))
block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))
block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))
def _convert_openclip_txt(module: TextTransformer, prefix):
module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))
pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)
module.positional_embedding.copy_(pos_embed_w)
_convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')
module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))
module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))
module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))
_convert_timm_img(model.visual.trunk, 'params/img/')
_convert_openclip_txt(model.text, 'params/txt/')
model.logit_bias.copy_(_n2p(w['params/b'])[0])
model.logit_scale.copy_(_n2p(w['params/t'])[0])
@torch.no_grad()
def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):
def _convert_timm_img(state_dict):
if fastvit:
from timm.models.fastvit import checkpoint_filter_fn
else:
from timm.models.vision_transformer_hybrid import checkpoint_filter_fn
timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)
timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()}
return timm_state_dict
def _convert_openclip_txt(state_dict, prefix='text_encoder.'):
text_dict = {}
for k, v in state_dict.items():
if not k.startswith(prefix):
continue
k = k.replace(prefix, '')
k = k.replace('projection_layer', 'text_projection')
k = k.replace('embedding_layer', 'token_embedding')
if k.startswith('positional_embedding.pos_embed.pos_embed'):
k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')
v = v.squeeze()
k = k.replace('final_layer_norm', 'ln_final')
k = k.replace('pre_norm_mha.0', 'ln_1')
k = k.replace('pre_norm_mha.1', 'attn')
k = k.replace('pre_norm_ffn.0', 'ln_2')
k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')
k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')
k = k.replace('qkv_proj.weight', 'in_proj_weight')
k = k.replace('qkv_proj.bias', 'in_proj_bias')
k = k.replace('transformer.', 'transformer.resblocks.')
text_dict['text.' + k] = v
return text_dict
image_dict = _convert_timm_img(state_dict)
text_dict = _convert_openclip_txt(state_dict)
out_dict = {**image_dict, **text_dict}
out_dict['logit_scale'] = state_dict['logit_scale']
return out_dict
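# Key-remapping examples implied by the replacements above:
#   'text_encoder.transformer.0.pre_norm_mha.0.weight' -> 'text.transformer.resblocks.0.ln_1.weight'
#   'text_encoder.projection_layer.weight'             -> 'text.text_projection.weight'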
def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):
if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
# Apple MobileCLIP s1 & s2 state_dicts (s0 not currently supported; the b model is handled below)
state_dict = convert_mobile_clip_state_dict(model, state_dict)
if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:
# convert b model
state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)
return state_dict
================================================
FILE: inf_clip/train/data.py
================================================
import ast
import json
import logging
import math
import os
import random
import sys
import braceexpand
from dataclasses import dataclass
from multiprocessing import Value
import numpy as np
import pandas as pd
import torch
import torchvision.datasets as datasets
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
class CsvDataset(Dataset):
def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
logging.debug(f'Loading csv data from {input_filename}.')
df = pd.read_csv(input_filename, sep=sep)
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
logging.debug('Done loading data.')
self.tokenize = tokenizer
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
texts = self.tokenize([str(self.captions[idx])])[0]
return images, texts
class SharedEpoch:
def __init__(self, epoch: int = 0):
self.shared_epoch = Value('i', epoch)
def set_value(self, epoch):
self.shared_epoch.value = epoch
def get_value(self):
return self.shared_epoch.value
@dataclass
class DataInfo:
dataloader: DataLoader
sampler: DistributedSampler = None
shared_epoch: SharedEpoch = None
def set_epoch(self, epoch):
if self.shared_epoch is not None:
self.shared_epoch.set_value(epoch)
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
def expand_urls(urls, weights=None):
if weights is None:
expanded_urls = wds.shardlists.expand_urls(urls)
return expanded_urls, None
if isinstance(urls, str):
urllist = urls.split("::")
weights = weights.split('::')
assert len(weights) == len(urllist),\
f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match."
weights = [float(weight) for weight in weights]
all_urls, all_weights = [], []
for url, weight in zip(urllist, weights):
expanded_url = list(braceexpand.braceexpand(url))
expanded_weights = [weight for _ in expanded_url]
all_urls.extend(expanded_url)
all_weights.extend(expanded_weights)
return all_urls, all_weights
else:
all_urls = list(urls)
return all_urls, weights
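# e.g. expand_urls('data/a-{0..1}.tar::data/b-{0..0}.tar', weights='2::1')
#   -> (['data/a-0.tar', 'data/a-1.tar', 'data/b-0.tar'], [2.0, 2.0, 1.0])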
def get_dataset_size(shards):
shards_list, _ = expand_urls(shards)
dir_path = os.path.dirname(shards_list[0])
sizes_filename = os.path.join(dir_path, 'sizes.json')
len_filename = os.path.join(dir_path, '__len__')
if os.path.exists(sizes_filename):
sizes = json.load(open(sizes_filename, 'r'))
total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
elif os.path.exists(len_filename):
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
total_size = ast.literal_eval(open(len_filename, 'r').read())
else:
total_size = None # num samples undefined
# some common dataset sizes (at the time of the authors' last download)
# CC3M (train): 2905954
# CC12M: 10968539
# LAION-400M: 407332084
# LAION-2B (english): 2170337258
num_shards = len(shards_list)
return total_size, num_shards
def get_imagenet(args, preprocess_fns, split):
assert split in ["train", "val", "v2"]
is_train = split == "train"
preprocess_train, preprocess_val = preprocess_fns
if split == "v2":
from imagenetv2_pytorch import ImageNetV2Dataset
dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
else:
if is_train:
data_path = args.imagenet_train
preprocess_fn = preprocess_train
else:
data_path = args.imagenet_val
preprocess_fn = preprocess_val
assert data_path
dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
if is_train:
idxs = np.zeros(len(dataset.targets))
target_array = np.array(dataset.targets)
k = 50
for c in range(1000):
m = target_array == c
n = len(idxs[m])
arr = np.zeros(n)
arr[:k] = 1
np.random.shuffle(arr)
idxs[m] = arr
idxs = idxs.astype('int')
sampler = SubsetRandomSampler(np.where(idxs)[0])
else:
sampler = None
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.workers,
sampler=sampler,
pin_memory=True,
)
return DataInfo(dataloader=dataloader, sampler=sampler)
def count_samples(dataloader):
os.environ["WDS_EPOCH"] = "0"
n_elements, n_batches = 0, 0
for images, texts in dataloader:
n_batches += 1
n_elements += len(images)
assert len(images) == len(texts)
return n_elements, n_batches
def filter_no_caption_or_no_image(sample):
has_caption = ('txt' in sample or 'json' in sample)
has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
return has_caption and has_image
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
# FIXME this is a bit of a hack to handle the fact that the CC3M/LAION400m dataset has some empty files.
try:
fname, value = filesample["fname"], filesample["data"]
except KeyError as exn:
continue
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
# FIXME the webdataset version throws if suffix is already in current_sample, but that can
# happen in the current LAION400m dataset if one tar ends with the same prefix the next one
# begins with; rare, but possible since prefixes aren't unique across tar files in that dataset
if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
def pytorch_worker_seed(increment=0):
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour using the seed already created for pytorch dataloader workers if it exists
seed = worker_info.seed
if increment:
# space out seed increments so they can't overlap across workers in different iterations
seed += increment * max(1, worker_info.num_workers)
return seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
def json_fetch(data, key='caption'):
for sample in data:
if 'json' in sample:
if isinstance(sample['json'], dict):
value = sample['json'].get(key, None)
else:
value = sample['json']
else:
value = sample['txt']
if isinstance(value, str):
sample['txt'] = [value]
elif isinstance(value, list):
sample['txt'] = value
else:
# print(f"Expected {key} to be a string or list of strings, got {type(value)} {sample}")
continue
yield sample
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
# If seed is negative, we use the worker's seed; this will be different across all nodes/workers
seed = pytorch_worker_seed(epoch)
else:
# This makes the seed deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
class ResampledShards2(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
weights=None,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
epoch=-1,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls, weights = expand_urls(urls, weights)
self.urls = urls
self.weights = weights
if self.weights is not None:
assert len(self.urls) == len(self.weights),\
f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
self.worker_seed = worker_seed
self.deterministic = deterministic
self.epoch = epoch
def __iter__(self):
"""Return an iterator over the shards."""
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.deterministic:
# reset seed w/ epoch if deterministic
if self.worker_seed is None:
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
seed = pytorch_worker_seed(epoch)
else:
seed = self.worker_seed() + epoch
self.rng.seed(seed)
for _ in range(self.nshards):
if self.weights is None:
yield dict(url=self.rng.choice(self.urls))
else:
yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
input_shards = args.train_data if is_train else args.val_data
assert input_shards is not None
resampled = getattr(args, 'dataset_resampled', False) and is_train
num_shards = None
if is_train:
if args.train_num_samples is not None:
num_samples = args.train_num_samples
else:
num_samples, num_shards = get_dataset_size(input_shards)
if not num_samples:
raise RuntimeError(
'Currently, the number of dataset samples must be specified for the training dataset. '
'Please specify it via `--train-num-samples` if no dataset length info is present.')
else:
# Eval will just exhaust the iterator if the size is not specified.
num_samples = args.val_num_samples or 0
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
if is_train and args.train_data_upsampling_factors is not None:
assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
if resampled:
pipeline = [ResampledShards2(
input_shards,
weights=args.train_data_upsampling_factors,
deterministic=True,
epoch=shared_epoch,
)]
else:
pipeline = [wds.SimpleShardList(input_shards)]
# at this point we have an iterator over all the shards
if is_train:
if not resampled:
pipeline.extend([
detshuffle2(
bufsize=_SHARD_SHUFFLE_SIZE,
initial=_SHARD_SHUFFLE_INITIAL,
seed=args.seed,
epoch=shared_epoch,
),
wds.split_by_node,
wds.split_by_worker,
])
pipeline.extend([
# at this point, we have an iterator over the shards assigned to each worker at each node
tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
bufsize=_SAMPLE_SHUFFLE_SIZE,
initial=_SAMPLE_SHUFFLE_INITIAL,
),
])
else:
pipeline.extend([
wds.split_by_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
])
pipeline.extend([
wds.select(filter_no_caption_or_no_image),
wds.decode("pilrgb", handler=log_and_continue),
json_fetch,
wds.rename(image="jpg;png;jpeg;webp", text="txt;json"),
wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
wds.to_tuple("image", "text"),
wds.batched(args.batch_size, partial=not is_train)
])
dataset = wds.DataPipeline(*pipeline)
if is_train:
if not resampled:
num_shards = num_shards or len(expand_urls(input_shards)[0])
assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
# roll over and repeat a few samples to get same number of full batches on each node
round_fn = math.floor if floor else math.ceil
global_batch_size = args.batch_size * args.world_size
num_batches = round_fn(num_samples / global_batch_size)
num_workers = max(1, args.workers)
num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker
num_batches = num_worker_batches * num_workers
num_samples = num_batches * global_batch_size
dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this
else:
# last batches are partial, eval is done on single (master) node
num_batches = math.ceil(num_samples / args.batch_size)
dataloader = wds.WebLoader(
dataset,
batch_size=None,
shuffle=False,
num_workers=args.workers,
persistent_workers=args.workers > 0,
)
# FIXME not clear which approach is better, with_epoch before vs after dataloader?
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169
# if is_train:
# # roll over and repeat a few samples to get same number of full batches on each node
# global_batch_size = args.batch_size * args.world_size
# num_batches = math.ceil(num_samples / global_batch_size)
# num_workers = max(1, args.workers)
# num_batches = math.ceil(num_batches / num_workers) * num_workers
# num_samples = num_batches * global_batch_size
# dataloader = dataloader.with_epoch(num_batches)
# else:
# # last batches are partial, eval is done on single (master) node
# num_batches = math.ceil(num_samples / args.batch_size)
# add meta-data to dataloader instance for convenience
dataloader.num_batches = num_batches
dataloader.num_samples = num_samples
dataloader.batch_size = args.batch_size
return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
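# Batch-sizing sketch (train, not resampled, illustrative numbers): num_samples=10_000,
# world_size=2, batch_size=64, workers=4, floor=False -> global_batch_size=128,
# ceil(10000/128)=79 batches, per-worker ceil(79/4)=20, so num_batches=80 and
# num_samples=80*128=10_240 (a few samples are repeated to keep nodes in sync).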
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
input_filename = args.train_data if is_train else args.val_data
assert input_filename
dataset = CsvDataset(
input_filename,
preprocess_fn,
img_key=args.csv_img_key,
caption_key=args.csv_caption_key,
sep=args.csv_separator,
tokenizer=tokenizer
)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
class SyntheticDataset(Dataset):
def __init__(
self,
transform=None,
image_size=(224, 224),
caption="Dummy caption",
dataset_size=100,
tokenizer=None,
):
self.transform = transform
self.image_size = image_size
self.caption = caption
self.image = Image.new('RGB', image_size)
self.dataset_size = dataset_size
self.preprocess_txt = lambda text: tokenizer(text)[0]
def __len__(self):
return self.dataset_size
def __getitem__(self, idx):
if self.transform is not None:
image = self.transform(self.image)
return image, self.preprocess_txt(self.caption)
def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
image_size = preprocess_fn.transforms[0].size
dataset = SyntheticDataset(
transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)
num_samples = len(dataset)
sampler = DistributedSampler(dataset) if args.distributed and is_train else None
shuffle = is_train and sampler is None
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.workers,
pin_memory=True,
sampler=sampler,
drop_last=is_train,
)
dataloader.num_samples = num_samples
dataloader.num_batches = len(dataloader)
return DataInfo(dataloader, sampler)
def get_dataset_fn(data_path, dataset_type):
if dataset_type == "webdataset":
return get_wds_dataset
elif dataset_type == "csv":
return get_csv_dataset
elif dataset_type == "synthetic":
return get_synthetic_dataset
elif dataset_type == "auto":
ext = data_path.split('.')[-1]
if ext in ['csv', 'tsv']:
return get_csv_dataset
elif ext in ['tar']:
return get_wds_dataset
else:
raise ValueError(
f"Tried to figure out dataset type, but failed for extension {ext}.")
else:
raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
preprocess_train, preprocess_val = preprocess_fns
data = {}
if args.train_data or args.dataset_type == "synthetic":
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
if args.val_data:
data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
args, preprocess_val, is_train=False, tokenizer=tokenizer)
if args.imagenet_val is not None:
data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
if args.imagenet_v2 is not None:
data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
return data
================================================
FILE: inf_clip/train/engine.py
================================================
import json
import logging
import math
import os
import time
from contextlib import nullcontext
import numpy as np
import pynvml
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as torch_checkpoint
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel
from tqdm import tqdm
from .utils import get_autocast, is_master
from inf_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \
IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, CLIP, CustomTextCLIP
from inf_clip.models.loss import ClipLoss
try:
import wandb
except ImportError:
wandb = None
def accuracy(output, target, topk=(1,)):
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
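# Sketch (illustration only): accuracy() returns correct-prediction *counts*,
# not rates; callers such as zero_shot_run below divide by n afterwards.
def _demo_accuracy():
    logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    target = torch.tensor([1, 1])
    top1, top2 = accuracy(logits, target, topk=(1, 2))
    return top1, top2  # 1.0, 2.0 (only the first sample is right at top-1)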
def get_clip_metrics(image_features, text_features, logit_scale):
metrics = {}
logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
logits_per_text = logits_per_image.t().detach().cpu()
logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
ground_truth = torch.arange(len(text_features)).view(-1, 1)
for name, logit in logits.items():
ranking = torch.argsort(logit, descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.detach().cpu().numpy()
metrics[f"{name}_mean_rank"] = preds.mean() + 1
metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"{name}_R@{k}"] = np.mean(preds < k)
return metrics
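# Sketch (illustration only): get_clip_metrics expects aligned image/text
# feature matrices of the same length; with random features, R@k lands near k / N.
def _demo_clip_metrics(n=8, d=16):
    torch.manual_seed(0)
    image_features = torch.randn(n, d)
    text_features = torch.randn(n, d)
    return get_clip_metrics(image_features, text_features, logit_scale=torch.tensor(1.0))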
def maybe_compute_generative_loss(model_out):
if "logits" in model_out and "labels" in model_out:
token_logits = model_out["logits"]
token_labels = model_out["labels"]
return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)
def get_memory():
pynvml.nvmlInit()
    # NOTE: the device index is hard-coded to 0, so this reports GPU 0 regardless of the process's local rank.
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
return meminfo.used / 1024**3
def seconds_to_hms(seconds):
hours, remainder = divmod(seconds, 3600)
minutes, seconds = divmod(remainder, 60)
hours = int(hours); minutes = int(minutes); seconds = int(seconds)
return f"{hours}:{minutes:02d}:{seconds:02d}"
def cal_grad_norm(model):
total_norm = 0
for p in model.parameters():
if p.grad is not None:
param_norm = p.grad.data.norm(2)
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
return total_norm
def assign_learning_rate(optimizer, new_lr):
for param_group in optimizer.param_groups:
param_group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
def const_lr(optimizer, base_lr, warmup_length, steps):
def _lr_adjuster(step):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
lr = base_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
def _lr_adjuster(step):
start_cooldown_step = steps - cooldown_steps
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
if step < start_cooldown_step:
lr = base_lr
else:
e = step - start_cooldown_step
es = steps - start_cooldown_step
# linear decay if power == 1; polynomial decay otherwise;
decay = (1 - (e/es)) ** cooldown_power
lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
def cosine_lr(optimizer, base_lr, warmup_length, steps):
def _lr_adjuster(step):
if step < warmup_length:
lr = _warmup_lr(base_lr, warmup_length, step)
else:
e = step - warmup_length
es = steps - warmup_length
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
assign_learning_rate(optimizer, lr)
return lr
return _lr_adjuster
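# Sketch (illustration only) of driving the scheduler closures above: each call
# assigns the step's lr to the optimizer and returns it; warmup ramps linearly
# to base_lr, then the cosine branch decays toward zero.
def _demo_cosine_schedule():
    param = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([param], lr=0.0)
    sched = cosine_lr(opt, base_lr=1e-3, warmup_length=10, steps=100)
    lrs = [sched(step) for step in range(100)]
    assert abs(lrs[9] - 1e-3) < 1e-12  # warmup ends exactly at base_lr
    assert lrs[-1] < lrs[10]           # cosine decay afterwards
    return lrs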
def postprocess_clip_output(model_out):
return {
"image_features": model_out[0],
"text_features": model_out[1],
"logit_scale": model_out[2]
}
def unwrap_model(model):
if hasattr(model, 'module'):
return model.module
else:
return model
def backward(total_loss, scaler):
if scaler is not None:
scaler.scale(total_loss).backward()
else:
total_loss.backward()
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class GradientAccum:
def __init__(self, model, loss, scaler, autocast, input_dtype, device):
self.model = model
self.loss = loss
self.scaler = scaler
self.autocast = autocast
self.input_dtype = input_dtype
self.device = device
self.logit_scale = unwrap_model(model).logit_scale
self.arch_type = unwrap_model(model).arch_type
self.accum_freq = 0
self.accum_cpu_states = []
self.accum_gpu_devices_states = []
self.accum_images = []
self.accum_texts = []
self.accum_image_features = []
self.accum_text_features = []
self.rank = dist.get_rank()
def clear(self):
self.accum_image_features.clear()
self.accum_text_features.clear()
torch.cuda.empty_cache()
def clear_state(self):
self.accum_images.clear()
self.accum_texts.clear()
self.accum_cpu_states.clear()
self.accum_gpu_devices_states.clear()
self.accum_freq = 0
@torch.no_grad()
def accum_inference(self, images, texts):
images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)
texts = texts.to(device=self.device, non_blocking=True)
# First, cache the features without any gradient tracking.
with self.autocast():
# collect rand states
self.accum_cpu_states.append(torch.get_rng_state())
self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))
model_out = self.model(images, texts)
self.accum_image_features.append(model_out["image_features"].detach().clone())
self.accum_text_features.append(model_out["text_features"].detach().clone())
if self.arch_type == "lit":
# lit
accum_image = model_out["image_trunk_features"].detach().clone()
else:
accum_image = images.detach().clone()
accum_text = texts.detach().clone()
# offloading
accum_image = accum_image.cpu()
accum_text = accum_text.cpu()
self.accum_images.append(accum_image)
        self.accum_texts.append(accum_text)
self.accum_freq += 1
def accum_forward_backward(self):
accum_losses = {"loss": 0.0}
for j in range(self.accum_freq):
            # inputs were offloaded to CPU in accum_inference; move them back to the device
            images = self.accum_images[j].to(device=self.device, non_blocking=True)
            texts = self.accum_texts[j].to(device=self.device, non_blocking=True)
            # refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235
            # DDP syncs gradients across GPUs on every backward, which is unnecessary for all but the last micro-batch.
            sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext
            with torch.random.fork_rng(devices=(self.device,)), sync_context():
                # restore the random states captured during accum_inference
                torch.set_rng_state(self.accum_cpu_states[j])
                torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j])
                with self.autocast():
model_out = self.model(images, texts)
inputs_no_accum = {}
inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
if "logit_bias" in model_out:
inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")
inputs = {}
inputs["image_features"] = torch.cat(self.accum_image_features[:j] + [model_out["image_features"]] + self.accum_image_features[j + 1:])
inputs["text_features"] = torch.cat(self.accum_text_features[:j] + [model_out["text_features"]] + self.accum_text_features[j + 1:])
losses = self.loss(**inputs, **inputs_no_accum)
show_loss = losses.pop("show_loss")
total_loss = sum(losses.values())
losses["loss"] = show_loss
del inputs
del inputs_no_accum
                backward(total_loss, self.scaler)
            accum_losses["loss"] += losses["loss"]
        accum_losses["loss"] /= self.accum_freq
self.clear()
self.clear_state()
return accum_losses
class GradientCache:
def __init__(self, model, loss, scaler, autocast, input_dtype, device):
self.model = model
self.loss = loss
self.scaler = scaler
self.autocast = autocast
self.input_dtype = input_dtype
self.device = device
self.logit_scale = unwrap_model(model).logit_scale
self.arch_type = unwrap_model(model).arch_type
self.accum_freq = 0
self.accum_cpu_states = []
self.accum_gpu_devices_states = []
self.accum_images = []
self.accum_texts = []
self.accum_image_features = []
self.accum_text_features = []
self.rank = dist.get_rank()
def clear(self):
self.accum_image_features.clear()
self.accum_text_features.clear()
torch.cuda.empty_cache()
def clear_state(self):
self.accum_images.clear()
self.accum_texts.clear()
self.accum_cpu_states.clear()
self.accum_gpu_devices_states.clear()
self.accum_freq = 0
def forward_backward(self, images, texts):
images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)
texts = texts.to(device=self.device, non_blocking=True)
with self.autocast():
model_out = self.model(image=images, text=texts)
model_out.pop("image_trunk_features", None)
losses = self.loss(**model_out)
show_loss = losses.pop("show_loss")
total_loss = sum(losses.values())
losses["loss"] = show_loss
backward(total_loss, self.scaler)
return losses
@torch.no_grad()
def accum_inference(self, images, texts):
images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)
texts = texts.to(device=self.device, non_blocking=True)
# First, cache the features without any gradient tracking.
with self.autocast():
# collect rand states
self.accum_cpu_states.append(torch.get_rng_state())
self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))
model_out = self.model(image=images, text=texts)
self.accum_image_features.append(model_out["image_features"])
self.accum_text_features.append(model_out["text_features"])
# Speed analysis of detach().clone(): https://stackoverflow.com/questions/55266154/pytorch-preferred-way-to-copy-a-tensor
if self.arch_type == "lit":
# lit
accum_image = model_out["image_trunk_features"].detach().clone()
else:
accum_image = images.detach().clone()
accum_text = texts.detach().clone()
# offloading
# accum_image = accum_image.cpu()
# accum_text = accum_text.cpu()
self.accum_images.append(accum_image)
self.accum_texts.append(accum_text)
self.accum_freq += 1
def accum_forward_backward(self):
        # turn the cached (no-grad) features into leaves so their grads can be read back after the loss backward
        accum_qs = [x.requires_grad_() for x in self.accum_image_features]
        qs = torch.cat(accum_qs, dim=0)
        accum_ks = [x.requires_grad_() for x in self.accum_text_features]
        ks = torch.cat(accum_ks, dim=0)
ls = self.logit_scale.exp().detach().clone().requires_grad_()
losses = self.loss(image_features=qs, text_features=ks, logit_scale=ls)
show_loss = losses.pop("show_loss")
total_loss = sum(losses.values())
losses["loss"] = show_loss
backward(total_loss, self.scaler)
accum_q_grads = [q.grad for q in accum_qs]
accum_k_grads = [k.grad for k in accum_ks]
l_grad = ls.grad
del accum_qs, accum_ks
del qs, ks, ls
# Clean trash memory from loss calculation or inference
self.clear()
for j in range(self.accum_freq):
images = self.accum_images[j]
texts = self.accum_texts[j]
# refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235
            # DDP syncs gradients across GPUs on every backward, which is unnecessary for all but the last micro-batch.
sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext
with torch.random.fork_rng(devices=(self.device, )), sync_context():
# setting random states
torch.set_rng_state(self.accum_cpu_states[j])
torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j])
with self.autocast():
if self.arch_type == "lit":
model_out = self.model(images, texts, project_only=True)
else:
model_out = self.model(images, texts)
q = model_out["image_features"]
k = model_out["text_features"]
l = model_out["logit_scale"]
                # each micro-batch backward re-adds the logit-scale term, so split l_grad evenly across accum_freq
                _loss = torch.dot(q.flatten(), accum_q_grads[j].flatten()) + \
                        torch.dot(k.flatten(), accum_k_grads[j].flatten()) + \
                        l * l_grad / self.accum_freq
                _loss.backward()
self.clear_state()
return losses
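# A minimal single-process sketch of the gradient-cache trick used by
# accum_inference / accum_forward_backward above (toy encoder, no autocast or
# RNG replay; all names are illustrative): backpropagating
# dot(recomputed_features, cached_feature_grads) per micro-batch reproduces the
# parameter gradients of one full-batch backward.
def _demo_grad_cache_equivalence():
    torch.manual_seed(0)
    enc = torch.nn.Linear(4, 4, bias=False)
    x, target = torch.randn(6, 4), torch.randn(6, 4)
    # reference: ordinary full-batch backward
    F.mse_loss(enc(x), target).backward()
    ref_grad = enc.weight.grad.clone()
    enc.weight.grad = None
    # stage 1: cache features without gradient tracking (cf. accum_inference)
    chunks = x.split(2)
    with torch.no_grad():
        feats = [enc(c) for c in chunks]
    feats = [f.requires_grad_() for f in feats]
    F.mse_loss(torch.cat(feats), target).backward()
    feat_grads = [f.grad for f in feats]
    # stage 2: re-run each chunk with grad enabled and inject the cached grads
    for c, g in zip(chunks, feat_grads):
        torch.dot(enc(c).flatten(), g.flatten()).backward()
    assert torch.allclose(enc.weight.grad, ref_grad, atol=1e-5)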
def train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
device = torch.device(args.device)
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
model.train()
data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
runner = GradientCache(model, loss, scaler, autocast, input_dtype, device)
rest_iters = num_batches_per_epoch * (args.epochs - epoch)
losses_m = {}
global_batch_time_m = AverageMeter()
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
for i, batch in enumerate(dataloader):
i_accum = i // args.accum_freq
step = num_batches_per_epoch * epoch + i_accum
if not args.skip_scheduler:
scheduler(step)
images, texts = batch
data_time_m.update(time.time() - end)
optimizer.zero_grad()
if args.accum_freq == 1:
losses = runner.forward_backward(images, texts)
else:
runner.accum_inference(images, texts)
# If (i + 1) % accum_freq is not zero, move on to the next batch.
if ((i + 1) % args.accum_freq) > 0:
# FIXME this makes data time logging unreliable when accumulating
continue
# Now, ready to take gradients for the last accum_freq batches.
# Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
# Call backwards each time, but only step optimizer at the end.
losses = runner.accum_forward_backward()
if scaler is not None:
if args.horovod:
optimizer.synchronize()
scaler.unscale_(optimizer)
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
with optimizer.skip_synchronize():
scaler.step(optimizer)
else:
if args.grad_clip_norm is not None:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
scaler.step(optimizer)
scaler.update()
else:
if args.grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
optimizer.step()
# Note: we clamp to 4.6052 = ln(100), as in the original paper.
with torch.no_grad():
unwrap_model(model).logit_scale.clamp_(0, math.log(100))
global_batch_time_m.update(time.time() - end)
batch_time_m.update(time.time() - end)
end = time.time()
batch_count = i_accum + 1
if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
batch_size = len(images)
num_samples = batch_count * batch_size * args.accum_freq * args.world_size
samples_per_epoch = dataloader.num_samples
percent_complete = batch_count / num_batches_per_epoch
# NOTE loss is coarsely sampled, just master node and per log update
for key, val in losses.items():
if key not in losses_m:
losses_m[key] = AverageMeter()
losses_m[key].update(val.item(), batch_size)
logit_scale_scalar = unwrap_model(model).logit_scale.exp().item()
loss_log = " ".join(
[
f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
for loss_name, loss_m in losses_m.items()
]
)
samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
            grad_norm = cal_grad_norm(unwrap_model(model))
running_time = seconds_to_hms(time.time() - start_timestamp)
rest_iters = rest_iters - 1
whole_time = seconds_to_hms(time.time() - start_timestamp + rest_iters * global_batch_time_m.avg)
logging.info(
f"{running_time}<{whole_time} "
f"Epoch: {epoch + percent_complete:.2f} "
f"Data (t): {data_time_m.avg:.3f} "
f"Batch (t): {batch_time_m.avg:.3f} "
f"LR: {optimizer.param_groups[0]['lr']:5f} "
f"Grad Norm: {grad_norm:.3f} "
f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + " "
f"Memory: {get_memory():.2f}GB "
)
# Save train loss / etc. Using non avg meter values as loggers have their own smoothing
log_data = {
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
"samples_per_second": samples_per_second,
"samples_per_second_per_gpu": samples_per_second_per_gpu,
"scale": logit_scale_scalar,
"grad_norm": grad_norm,
"lr": optimizer.param_groups[0]["lr"]
}
log_data.update({name:val.val for name,val in losses_m.items()})
log_data = {"train/" + name: val for name, val in log_data.items()}
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, step)
if args.wandb:
assert wandb is not None, 'Please install wandb.'
log_data['step'] = step # for backwards compatibility
wandb.log(log_data, step=step)
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
# end for
def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
metrics = {}
if not is_master(args):
return metrics
device = torch.device(args.device)
model.eval()
zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer)
metrics.update(zero_shot_metrics)
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
dataloader = data['val'].dataloader
num_samples = 0
samples_per_val = dataloader.num_samples
# FIXME this does not scale past small eval datasets
# all_image_features @ all_text_features will blow up memory and compute very quickly
cumulative_loss = 0.0
cumulative_gen_loss = 0.0
all_image_features, all_text_features = [], []
with torch.inference_mode():
for i, batch in enumerate(dataloader):
images, texts = batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
with autocast():
model_out = model(images, texts)
image_features = model_out["image_features"]
text_features = model_out["text_features"]
logit_scale = model_out["logit_scale"]
# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
# however, system RAM is easily exceeded and compute time becomes problematic
all_image_features.append(image_features.cpu())
all_text_features.append(text_features.cpu())
logit_scale = logit_scale.mean()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
batch_size = images.shape[0]
labels = torch.arange(batch_size, device=device).long()
total_loss = (
F.cross_entropy(logits_per_image, labels) +
F.cross_entropy(logits_per_text, labels)
) / 2
gen_loss = maybe_compute_generative_loss(model_out)
cumulative_loss += total_loss * batch_size
num_samples += batch_size
if is_master(args) and (i % 100) == 0:
logging.info(
f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
f"Clip Loss: {cumulative_loss / num_samples:.6f}\t")
if gen_loss is not None:
cumulative_gen_loss += gen_loss * batch_size
logging.info(
f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t")
val_metrics = get_clip_metrics(
image_features=torch.cat(all_image_features),
text_features=torch.cat(all_text_features),
logit_scale=logit_scale.cpu(),
)
loss = cumulative_loss / num_samples
metrics.update(
{**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
)
if gen_loss is not None:
gen_loss = cumulative_gen_loss / num_samples
metrics.update({"val_generative_loss": gen_loss.item()})
if not metrics:
return metrics
logging.info(
f"Eval Epoch: {epoch} "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)
log_data = {"val/" + name: val for name, val in metrics.items()}
if args.save_logs:
if tb_writer is not None:
for name, val in log_data.items():
tb_writer.add_scalar(name, val, epoch)
with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
f.write(json.dumps(metrics))
f.write("\n")
if args.wandb:
assert wandb is not None, 'Please install wandb.'
if 'train' in data:
dataloader = data['train'].dataloader
num_batches_per_epoch = dataloader.num_batches // args.accum_freq
step = num_batches_per_epoch * epoch
else:
step = None
log_data['epoch'] = epoch
wandb.log(log_data, step=step)
return metrics
def zero_shot_run(model, classifier, dataloader, args):
autocast = get_autocast(args.precision)
input_dtype = get_input_dtype(args.precision)
with torch.inference_mode():
top1, top5, n = 0., 0., 0.
for images, target in tqdm(dataloader, unit_scale=args.batch_size):
images = images.to(device=args.device, dtype=input_dtype)
target = target.to(args.device)
with autocast():
# predict
output = model(image=images)
image_features = output['image_features'] if isinstance(output, dict) else output[0]
logits = 100. * image_features @ classifier
# measure accuracy
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
top1 += acc1
top5 += acc5
n += images.size(0)
top1 = (top1 / n)
top5 = (top5 / n)
return top1, top5
def zero_shot_eval(model, data, epoch, args, tokenizer=None):
if 'imagenet-val' not in data and 'imagenet-v2' not in data:
return {}
if args.zeroshot_frequency == 0:
return {}
if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
return {}
if args.distributed and not args.horovod:
model = model.module
logging.info('Starting zero-shot imagenet.')
if tokenizer is None:
tokenizer = get_tokenizer(args.model)
logging.info('Building zero-shot classifier')
autocast = get_autocast(args.precision)
with autocast():
classifier = build_zero_shot_classifier(
model,
tokenizer=tokenizer,
classnames=IMAGENET_CLASSNAMES,
templates=OPENAI_IMAGENET_TEMPLATES,
num_classes_per_batch=10,
device=args.device,
use_tqdm=True,
)
logging.info('Using classifier')
results = {}
if 'imagenet-val' in data:
top1, top5 = zero_shot_run(model, classifier, data['imagenet-val'].dataloader, args)
results['imagenet-zeroshot-val-top1'] = top1
results['imagenet-zeroshot-val-top5'] = top5
if 'imagenet-v2' in data:
top1, top5 = zero_shot_run(model, classifier, data['imagenet-v2'].dataloader, args)
results['imagenetv2-zeroshot-val-top1'] = top1
results['imagenetv2-zeroshot-val-top5'] = top5
logging.info('Finished zero-shot imagenet.')
return results
================================================
FILE: inf_clip/train/main.py
================================================
import glob
import logging
import os
import re
import subprocess
import sys
import random
import time
from functools import partial
import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler
try:
import wandb
except ImportError:
wandb = None
try:
import torch.utils.tensorboard as tensorboard
except ImportError:
tensorboard = None
try:
import horovod.torch as hvd
except ImportError:
hvd = None
from inf_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss
from inf_clip.train.data import get_data
from inf_clip.train.params import parse_args
from inf_clip.train.optims import ScalingViTAdafactor, Lion
from inf_clip.train.engine import cosine_lr, const_lr, const_lr_cooldown, train_one_epoch, evaluate
from inf_clip.train.utils import (setup_logging, pt_load, check_exists, start_sync_process, remote_sync, is_master, init_distributed_device, broadcast_object)
LATEST_CHECKPOINT_NAME = "epoch_latest.pt"
def random_seed(seed=42, rank=0):
random.seed(seed + rank)
np.random.seed(seed + rank)
torch.manual_seed(seed + rank)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
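# Sketch (illustration only): natural_key makes "epoch_10.pt" sort after
# "epoch_2.pt", which plain lexicographic order would reverse; this is what
# get_latest_checkpoint below relies on to pick the newest checkpoint.
def _demo_natural_sort():
    names = ["epoch_10.pt", "epoch_2.pt", "epoch_1.pt"]
    return sorted(names, key=natural_key)  # ['epoch_1.pt', 'epoch_2.pt', 'epoch_10.pt']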
def copy_codebase(args):
from shutil import copytree, ignore_patterns
new_code_path = os.path.join(args.log_dir, args.name, "code")
if os.path.exists(new_code_path):
print(
f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
)
return -1
print(f"Copying codebase to {new_code_path}")
current_code_path = os.path.realpath(__file__)
for _ in range(3):
current_code_path = os.path.dirname(current_code_path)
copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
print("Done copying code.")
return 1
def prepare_logging(args):
# get the name of the experiments
if args.name is None:
from datetime import datetime
# sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
model_name_safe = args.model.replace('/', '-')
date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S")
if args.distributed:
# sync date_str from master to all ranks
date_str = broadcast_object(args, date_str)
args.name = '-'.join([
date_str,
f"model_{model_name_safe}",
f"lr_{args.lr}",
f"b_{args.batch_size}",
f"j_{args.workers}",
f"p_{args.precision}",
])
resume_latest = args.resume == 'latest'
log_base_path = os.path.join(args.log_dir, args.name)
args.log_path = None
if is_master(args, local=args.log_local):
os.makedirs(log_base_path, exist_ok=True)
log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
args.log_path = os.path.join(log_base_path, log_filename)
# if os.path.exists(args.log_path) and not resume_latest:
# print(
# "Error. Experiment already exists. Use --name {} to specify a new experiment."
# )
# return -1
# Setup text logger
args.log_level = logging.DEBUG if args.debug else logging.INFO
setup_logging(args.log_path, args.log_level)
# Setup wandb, tensorboard, checkpoint logging
args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
if is_master(args):
args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
for dirname in [args.tensorboard_path, args.checkpoint_path]:
if dirname:
os.makedirs(dirname, exist_ok=True)
else:
args.tensorboard_path = ''
if args.copy_codebase:
copy_codebase(args)
return log_base_path, resume_latest
def get_latest_checkpoint(path: str, remote : bool):
    # as written, this glob recurses, so it can pick up checkpoints across multiple sub-folders
if remote:
result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
print(result)
if result.returncode == 1:
return None
checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]]
else:
checkpoints = glob.glob(path + '**/*.pt', recursive=True)
if checkpoints:
checkpoints = sorted(checkpoints, key=natural_key)
return checkpoints[-1]
return None
def prepare_resuming(args):
resume_from = None
checkpoint_path = args.checkpoint_path
# If using remote_sync, need to check the remote instead of the local checkpoints folder.
if args.remote_sync is not None:
checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
if args.save_most_recent:
print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
return -1
if args.remote_sync_protocol != 's3':
print('Error. Sync protocol not supported when using resume latest.')
return -1
if is_master(args):
# Checking for existing checkpoint via master rank only. It is possible for
# different rank processes to see different files if a shared file-system is under
# stress, however it's very difficult to fully work around such situations.
if args.save_most_recent:
# if --save-most-recent flag is set, look for latest at a fixed filename
resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
if not os.path.exists(resume_from):
# If no latest checkpoint has been saved yet, don't try to resume
resume_from = None
else:
# otherwise, list checkpoint dir contents and pick the newest checkpoint
resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
if resume_from:
logging.info(f'Found latest resume checkpoint at {resume_from}.')
else:
logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
if args.distributed:
# sync found checkpoint path to all ranks
resume_from = broadcast_object(args, resume_from)
args.resume = resume_from
def prepare_remote_sync(args):
    # start the sync process if remote-sync is not None
remote_sync_process = None
if is_master(args) and args.remote_sync is not None:
# first make sure it works
result = remote_sync(
os.path.join(args.log_dir, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
if result:
logging.info('remote sync successful.')
else:
logging.info('Error: remote sync failed. Exiting.')
return -1
# if all looks good, start a process to do this every args.remote_sync_frequency seconds
remote_sync_process = start_sync_process(
args.remote_sync_frequency,
os.path.join(args.log_dir, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
remote_sync_process.start()
return remote_sync_process
def prepare_model(args, device):
dist_model = None
args.distill = args.distill_model is not None and args.distill_pretrained is not None
if args.distill:
#FIXME: support distillation with grad accum.
assert args.accum_freq == 1
#FIXME: support distillation with coca.
assert 'coca' not in args.model.lower()
if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
# arg is nargs, single (square) image size list -> int
args.force_image_size = args.force_image_size[0]
random_seed(args.seed, 0)
model_kwargs = {}
if args.siglip:
model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP
model_kwargs['init_logit_bias'] = -10
model, preprocess_train, preprocess_val = create_model_and_transforms(
args.model,
args.pretrained,
precision=args.precision,
device=device,
jit=args.torchscript,
force_quick_gelu=args.force_quick_gelu,
force_custom_text=args.force_custom_text,
force_patch_dropout=args.force_patch_dropout,
force_image_size=args.force_image_size,
image_mean=args.image_mean,
image_std=args.image_std,
image_interpolation=args.image_interpolation,
image_resize_mode=args.image_resize_mode, # only effective for inference
aug_cfg=args.aug_cfg,
pretrained_image=args.pretrained_image,
output_dict=True,
**model_kwargs,
)
if args.distill:
# FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.
dist_model, _, _ = create_model_and_transforms(
args.distill_model,
args.distill_pretrained,
device=device,
precision=args.precision,
output_dict=True,
)
if args.use_bnb_linear is not None:
        print('=> using a layer from bitsandbytes.\n'
              '   this is an experimental feature which requires two extra pip installs\n'
              '   pip install bitsandbytes triton\n'
              '   please make sure to use triton 2.0.0')
import bitsandbytes as bnb
from open_clip.utils import replace_linear
print(f'=> replacing linear layers with {args.use_bnb_linear}')
linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear)
replace_linear(model, linear_replacement_cls)
model = model.to(device)
random_seed(args.seed, args.rank)
if args.trace:
model = trace_model(model, batch_size=args.batch_size, device=device)
if args.lock_image:
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
model.lock_image_tower(
unlocked_groups=args.lock_image_unlocked_groups,
freeze_bn_stats=args.lock_image_freeze_bn_stats)
if args.lock_text:
model.lock_text_tower(
unlocked_layers=args.lock_text_unlocked_layers,
freeze_layer_norm=args.lock_text_freeze_layer_norm)
if args.grad_checkpointing:
model.set_grad_checkpointing()
if args.distributed and not args.horovod:
if args.use_bn_sync:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
ddp_args = {}
if args.ddp_static_graph:
# this doesn't exist in older PyTorch, arg only added if enabled
ddp_args['static_graph'] = True
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)
if args.distill:
dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args)
if is_master(args):
logging.info("Model:")
logging.info(f"{str(model)}")
logging.info("Params:")
params_file = os.path.join(args.log_dir, args.name, "params.txt")
with open(params_file, "w") as f:
for name in sorted(vars(args)):
val = getattr(args, name)
logging.info(f" {name}: {val}")
f.write(f"{name}: {val}\n")
tokenizer = get_tokenizer(args.model)
return tokenizer, model, dist_model, preprocess_train, preprocess_val
def prepare_optimizer_scaler(args, model):
assert not args.trace, 'Cannot train with traced model'
exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
include = lambda n, p: not exclude(n, p)
named_parameters = list(model.named_parameters())
named_parameters = [(n, p) for n, p in named_parameters if p.requires_grad]
gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
if args.optimizer == "adam":
optimizer = optim.AdamW(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
eps=args.eps,
)
elif args.optimizer == "adafactor":
optimizer = ScalingViTAdafactor(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
beta1=args.beta1,
beta2=args.beta2,
)
elif args.optimizer == "lion":
optimizer = Lion(
[
{"params": gain_or_bias_params, "weight_decay": 0.},
{"params": rest_params, "weight_decay": args.wd},
],
lr=args.lr,
betas=(args.beta1, args.beta2),
)
# elif args.optim == "lamb":
if args.horovod:
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
scaler = GradScaler() if args.precision == "amp" else None
return optimizer, scaler
def prepare_scheduler(args, optimizer, num_batches):
scheduler = None
total_steps = (num_batches // args.accum_freq) * args.epochs
if args.lr_scheduler == "cosine":
scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
elif args.lr_scheduler == "const":
scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
elif args.lr_scheduler == "const-cooldown":
assert args.epochs_cooldown is not None,\
"Please specify the number of cooldown epochs for this lr schedule."
cooldown_steps = (num_batches // args.accum_freq) * args.epochs_cooldown
scheduler = const_lr_cooldown(
optimizer, args.lr, args.warmup, total_steps,
cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
else:
logging.error(
f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
exit(1)
return scheduler
def main(args):
args = parse_args(args)
if torch.cuda.is_available():
# This enables tf32 on Ampere GPUs which is only 8% slower than
# float16 and almost as accurate as float32
# This was a default in pytorch until 1.12
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
# fully initialize distributed device environment
device = init_distributed_device(args)
log_base_path, resume_latest = prepare_logging(args)
    if resume_latest:
        prepare_resuming(args)
remote_sync_process = prepare_remote_sync(args)
if args.horovod:
logging.info(
f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
elif args.distributed:
logging.info(
f'Running in distributed mode with multiple processes. Device: {args.device}.'
f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
else:
logging.info(f'Running with a single process. Device {args.device}.')
tokenizer, model, dist_model, preprocess_train, preprocess_val = prepare_model(args, device)
    # create optimizer and scaler (only needed when training)
    optimizer, scaler = None, None
    if args.train_data or args.dataset_type == "synthetic":
        optimizer, scaler = prepare_optimizer_scaler(args, model)
# optionally resume from a checkpoint
start_epoch = 0
if args.resume is not None:
checkpoint = pt_load(args.resume, map_location='cpu')
if 'epoch' in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
if optimizer is not None:
optimizer.load_state_dict(checkpoint["optimizer"])
if scaler is not None and 'scaler' in checkpoint:
scaler.load_state_dict(checkpoint['scaler'])
logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
# initialize datasets
data = get_data(
args,
(preprocess_train, preprocess_val),
epoch=start_epoch,
tokenizer=tokenizer,
)
assert len(data), 'At least one train or eval dataset must be specified.'
# create scheduler if train
if 'train' in data and optimizer is not None:
scheduler = prepare_scheduler(args, optimizer, data["train"].dataloader.num_batches)
# determine if this worker should save logs and checkpoints. only do so if it is rank == 0
args.save_logs = args.log_dir and args.log_dir.lower() != 'none' and is_master(args)
writer = None
if args.save_logs and args.tensorboard:
assert tensorboard is not None, "Please install tensorboard."
writer = tensorboard.SummaryWriter(args.tensorboard_path)
if args.wandb and is_master(args):
assert wandb is not None, 'Please install wandb.'
logging.debug('Starting wandb.')
args.train_sz = data["train"].dataloader.num_samples
if args.val_data is not None:
args.val_sz = data["val"].dataloader.num_samples
# you will have to configure this for your project!
wandb.init(
project=args.wandb_project_name,
name=args.name,
id=args.name,
notes=args.wandb_notes,
tags=[],
resume='auto' if args.resume == "latest" else None,
config=vars(args),
)
if args.debug:
wandb.watch(model, log='all')
            params_file = os.path.join(args.log_dir, args.name, "params.txt")  # written out by prepare_model
            wandb.save(params_file)
logging.debug('Finished loading wandb.')
# Pytorch 2.0 adds '_orig_mod.' prefix to keys of state_dict() of compiled models.
# For compatibility, we save state_dict() of the original model, which shares the
# weights without the prefix.
original_model = model
if args.torchcompile:
logging.info('Compiling model...')
model = torch.compile(original_model)
if 'train' not in data:
# If using int8, convert to inference mode.
if args.use_bnb_linear is not None:
from open_clip.utils import convert_int8_model_to_inference_mode
convert_int8_model_to_inference_mode(model)
# Evaluate.
evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer)
return
loss = create_loss(args)
start_timestamp = time.time()
for epoch in range(start_epoch, args.epochs):
if is_master(args):
logging.info(f'Start epoch {epoch}')
train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
completed_epoch = epoch + 1
if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer)
# Saving checkpoints.
if args.save_logs:
checkpoint_dict = {
"epoch": completed_epoch,
"name": args.name,
"state_dict": original_model.state_dict(),
"optimizer": optimizer.state_dict(),
}
if scaler is not None:
checkpoint_dict["scaler"] = scaler.state_dict()
if completed_epoch == args.epochs or (
args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
):
torch.save(
checkpoint_dict,
os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
)
if args.delete_previous_checkpoint:
previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
if os.path.exists(previous_checkpoint):
os.remove(previous_checkpoint)
if args.save_most_recent:
# try not to corrupt the latest checkpoint if save fails
tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
torch.save(checkpoint_dict, tmp_save_path)
os.replace(tmp_save_path, latest_save_path)
if args.wandb and is_master(args):
wandb.finish()
# run a final sync.
if remote_sync_process is not None:
logging.info('Final remote sync.')
remote_sync_process.terminate()
result = remote_sync(
os.path.join(args.log_dir, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol
)
if result:
logging.info('Final remote sync successful.')
else:
logging.info('Final remote sync failed.')
if __name__ == "__main__":
main(sys.argv[1:])
================================================
FILE: inf_clip/train/optims.py
================================================
import math
import torch
from torch import nn
from torch.optim import Optimizer
class ScalingViTAdafactor(Optimizer):
"""
Modified version of Adafactor in Transformers https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/optimization.py#L672,
which refers to Paper: *Scaling Vision Transformers* https://arxiv.org/pdf/2106.04560
1. Re-introducing the first momentum in half-precision.
2. Disable scaling of learning rate relative to weight norms, a feature that is part of Adafactor.
3. Clipping the second momentum at 0.999
"""
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=0.9,
beta2=0.999,
weight_decay=0.0,
scale_parameter=False,
relative_step=False,
warmup_init=False,
):
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"beta2": beta2,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
        # adapted from the Hugging Face transformers Adafactor implementation (which follows fairseq's):
        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
# NOTE: gradient keep in float32
# if grad.dtype in {torch.float16, torch.bfloat16}:
# grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# NOTE: using bfloat16 for first momentum
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad).bfloat16()
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
# NOTE: keep in float32
# if p.dtype in {torch.float16, torch.bfloat16}:
# p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
beta2t = min(beta2t, group["beta2"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
return loss
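# Sketch (illustration only): after one step on a 2-D parameter, the factored
# second-moment state holds a row vector and a column vector (O(m + n) memory
# instead of O(m * n)), and the first moment is kept in bfloat16 as noted above.
def _demo_adafactor_state():
    p = torch.nn.Parameter(torch.zeros(8, 16))
    opt = ScalingViTAdafactor([p], lr=1e-3)
    p.grad = torch.randn(8, 16)
    opt.step()
    state = opt.state[p]
    # torch.Size([8]), torch.Size([16]), torch.bfloat16
    return state["exp_avg_sq_row"].shape, state["exp_avg_sq_col"].shape, state["exp_avg"].dtype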
class Lion(Optimizer):
"""
Modified version of Lion in https://github.com/google/automl/blob/master/lion/lion_pytorch.py,
which refers to Paper: *Symbolic Discovery of Optimization Algorithms* https://arxiv.org/pdf/2302.06675
"""
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
"""Initialize the hyperparameters.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-4)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.99))
weight_decay (float, optional): weight decay coefficient (default: 0)
"""
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
# Perform stepweight decay
p.data.mul_(1 - group['lr'] * group['weight_decay'])
grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p)
exp_avg = state['exp_avg']
beta1, beta2 = group['betas']
# Weight update
update = exp_avg * beta1 + grad * (1 - beta1)
p.add_(update.sign_(), alpha=-group['lr'])
# Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
return loss
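# Sketch (illustration only): a single Lion step on a toy parameter. The update
# direction is sign(beta1 * m + (1 - beta1) * g), so every entry moves by
# exactly +/- lr regardless of gradient magnitude.
def _demo_lion_step():
    p = torch.nn.Parameter(torch.zeros(3))
    opt = Lion([p], lr=0.1, betas=(0.9, 0.99), weight_decay=0.0)
    p.grad = torch.tensor([1.0, -2.0, 0.5])
    opt.step()
    return p.detach().clone()  # tensor([-0.1000, 0.1000, -0.1000])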
================================================
FILE: inf_clip/train/params.py
================================================
import os
import argparse
import ast
import json
from .utils import world_info_from_env
def get_default_params(model_name):
# Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
model_name = model_name.lower()
if "vit" in model_name:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
else:
return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
for value in values:
key, value = value.split('=')
try:
kw[key] = ast.literal_eval(value)
except ValueError:
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
setattr(namespace, self.dest, kw)
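# Sketch (illustration only): ParseKwargs turns "key=value" tokens into a dict,
# literal-evaluating values where possible (quote values containing spaces on a
# real command line). Used below for --aug-cfg.
def _demo_parse_kwargs():
    p = argparse.ArgumentParser()
    p.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
    ns = p.parse_args(['--aug-cfg', 'scale=(0.6, 1.0)', 'color_jitter=0.4'])
    return ns.aug_cfg  # {'scale': (0.6, 1.0), 'color_jitter': 0.4}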
def parse_args(args):
parser = argparse.ArgumentParser()
parser.add_argument(
"--train-data",
type=str,
default=None,
help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.",
)
parser.add_argument(
"--train-data-upsampling-factors",
type=str,
default=None,
help=(
"When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. "
"Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) "
"By default, datapoints are sampled uniformly regardless of the dataset sizes."
)
)
parser.add_argument(
"--val-data",
type=str,
default=None,
help="Path to file(s) with validation data",
)
parser.add_argument(
"--train-num-samples",
type=int,
default=None,
help="Number of samples in dataset. Required for webdataset if not available in info file.",
)
parser.add_argument(
"--val-num-samples",
type=int,
default=None,
help="Number of samples in dataset. Useful for webdataset if not available in info file.",
)
parser.add_argument(
"--dataset-type",
choices=["webdataset", "csv", "synthetic", "auto"],
default="auto",
help="Which type of dataset to process."
)
parser.add_argument(
"--dataset-resampled",
default=False,
action="store_true",
help="Whether to use sampling with replacement for webdataset shard selection."
)
parser.add_argument(
"--csv-separator",
type=str,
default="\t",
help="For csv-like datasets, which separator to use."
)
parser.add_argument(
"--csv-img-key",
type=str,
default="filepath",
help="For csv-like datasets, the name of the key for the image paths."
)
parser.add_argument(
"--csv-caption-key",
type=str,
default="title",
help="For csv-like datasets, the name of the key for the captions."
)
parser.add_argument(
"--imagenet-val",
type=str,
default=None,
help="Path to imagenet val set for conducting zero shot evaluation.",
)
parser.add_argument(
"--imagenet-v2",
type=str,
default=None,
help="Path to imagenet v2 for conducting zero shot evaluation.",
)
parser.add_argument(
"--log_dir",
type=str,
default="./logs/",
help="Where to store tensorboard logs. Use None to avoid storing logs.",
)
parser.add_argument(
"--log-local",
action="store_true",
default=False,
help="log files on local master, otherwise global master only.",
)
parser.add_argument(
"--name",
type=str,
default=None,
help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
)
parser.add_argument(
"--workers", type=int, default=4, help="Number of dataloader workers per GPU."
)
parser.add_argument(
"--batch-size", type=int, default=64, help="Batch size per GPU."
)
parser.add_argument(
"--epochs", type=int, default=32, help="Number of epochs to train for."
)
parser.add_argument(
"--epochs-cooldown", type=int, default=None,
help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards."
)
parser.add_argument("--optimizer", type=str, default="adam", help="Optimizer to use.")
parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
parser.add_argument("--beta1", type=float, default=None, help="coefficient of moving average of first moment.")
parser.add_argument("--beta2", type=float, default=None, help="coefficient of moving average of second moment.")
parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
parser.add_argument(
"--warmup", type=int, default=10000, help="Number of steps to warmup for."
)
parser.add_argument(
"--use-bn-sync",
default=False,
action="store_true",
help="Whether to use batch norm sync.")
parser.add_argument(
"--skip-scheduler",
action="store_true",
default=False,
help="Use this flag to skip the learning rate decay.",
)
parser.add_argument(
"--lr-scheduler",
type=str,
default='cosine',
help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
)
parser.add_argument(
"--lr-cooldown-end", type=float, default=0.0,
help="End learning rate for cooldown schedule. Default: 0"
)
parser.add_argument(
"--lr-cooldown-power", type=float, default=1.0,
help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)"
)
parser.add_argument(
"--save-frequency", type=int, default=1, help="How often to save checkpoints."
)
parser.add_argument(
"--save-most-recent",
action="store_true",
default=False,
help="Always save the most recent model trained to epoch_latest.pt.",
)
parser.add_argument(
"--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
)
parser.add_argument(
"--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
)
parser.add_argument(
"--resume",
default=None,
type=str,
help="path to latest checkpoint (default: none)",
)
parser.add_argument(
"--precision",
choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"],
default="amp",
help="Floating point precision."
)
parser.add_argument(
"--model",
type=str,
default="RN50",
help="Name of the vision backbone to use.",
)
parser.add_argument(
"--pretrained",
default='',
type=str,
help="Use a pretrained CLIP model weights with the specified tag or file path.",
)
parser.add_argument(
"--pretrained-image",
default=False,
action='store_true',
help="Load imagenet pretrained weights for image tower backbone if available.",
)
parser.add_argument(
"--lock-image",
default=False,
action='store_true',
help="Lock full image tower by disabling gradients.",
)
parser.add_argument(
"--lock-image-unlocked-groups",
type=int,
default=0,
help="Leave last n image tower layer groups unlocked.",
)
parser.add_argument(
"--lock-image-freeze-bn-stats",
default=False,
action='store_true',
help="Freeze BatchNorm running stats in image tower for any locked layers.",
)
parser.add_argument(
'--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override default image mean value of dataset')
parser.add_argument(
'--image-std', type=float, nargs='+', default=None, metavar='STD',
help='Override default image std deviation of dataset')
parser.add_argument(
'--image-interpolation',
default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
help="Override default image resize interpolation"
)
parser.add_argument(
'--image-resize-mode',
default=None, type=str, choices=['shortest', 'longest', 'squash'],
help="Override default image resize (& crop) mode during inference"
)
parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
parser.add_argument(
"--grad-checkpointing",
default=False,
action='store_true',
help="Enable gradient checkpointing.",
)
parser.add_argument(
"--local-loss",
default=False,
action="store_true",
help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
)
parser.add_argument(
"--gather-with-grad",
default=False,
action="store_true",
help="enable full distributed gradient for feature gather"
)
parser.add_argument(
'--force-image-size', type=int, nargs='+', default=None,
help='Override default image size'
)
parser.add_argument(
"--force-quick-gelu",
default=False,
action='store_true',
help="Force use of QuickGELU activation for non-OpenAI transformer models.",
)
parser.add_argument(
"--force-patch-dropout",
default=None,
type=float,
help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",
)
parser.add_argument(
"--force-custom-text",
default=False,
action='store_true',
help="Force use of CustomTextCLIP model (separate text-tower).",
)
parser.add_argument(
"--torchscript",
default=False,
action='store_true',
help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
)
parser.add_argument(
"--torchcompile",
default=False,
action='store_true',
help="torch.compile() the model, requires pytorch 2.0 or later.",
)
parser.add_argument(
"--trace",
default=False,
action='store_true',
help="torch.jit.trace the model for inference / eval only",
)
parser.add_argument(
"--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
)
# arguments for distributed training
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--report-to",
default='',
type=str,
help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
)
parser.add_argument(
"--wandb-notes",
default='',
type=str,
help="Notes if logging with wandb"
)
parser.add_argument(
"--wandb-project-name",
type=str,
default='open-clip',
help="Name of the project if logging with wandb.",
)
parser.add_argument(
"--debug",
default=False,
action="store_true",
help="If true, more information is logged."
)
parser.add_argument(
"--copy-codebase",
default=False,
action="store_true",
help="If true, we copy the entire base on the log directory, and execute from there."
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training."
)
parser.add_argument(
"--ddp-static-graph",
default=False,
action='store_true',
help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
)
parser.add_argument(
"--seed", type=int, default=0, help="Default random seed."
)
parser.add_argument(
"--grad-clip-norm", type=float, default=None, help="Gradient clip."
)
parser.add_argument(
"--lock-text",
default=False,
action='store_true',
help="Lock full text tower by disabling gradients.",
)
parser.add_argument(
"--lock-text-unlocked-layers",
type=int,
default=0,
help="Leave last n text tower layer groups unlocked.",
)
parser.add_argument(
"--lock-text-freeze-layer-norm",
default=False,
action='store_true',
help="Freeze LayerNorm running stats in text tower for any locked layers.",
)
parser.add_argument(
"--log-every-n-steps",
type=int,
default=100,
help="Log every n steps to tensorboard/console/wandb.",
)
parser.add_argument(
"--coca-caption-loss-weight",
type=float,
default=2.0,
help="Weight assigned to caption loss in CoCa."
)
parser.add_argument(
"--coca-contrastive-loss-weight",
type=float,
default=1.0,
help="Weight assigned to contrastive loss when training CoCa."
)
parser.add_argument(
"--remote-sync",
type=str,
default=None,
help="Optinoally sync with a remote path specified by this arg",
)
parser.add_argument(
"--remote-sync-frequency",
type=int,
default=300,
help="How frequently to sync to a remote directly if --remote-sync is not None.",
)
parser.add_argument(
"--remote-sync-protocol",
choices=["s3", "fsspec"],
default="s3",
help="How to do the remote sync backup if --remote-sync is not None.",
)
parser.add_argument(
"--delete-previous-checkpoint",
default=False,
action="store_true",
help="If true, delete previous checkpoint after storing a new one."
)
parser.add_argument(
"--distill-model",
default=None,
help='Which model arch to distill from, if any.'
)
parser.add_argument(
"--distill-pretrained",
default=None,
help='Which pre-trained weights to distill from, if any.'
)
parser.add_argument(
"--use-bnb-linear",
default=None,
        help='Replace the network linear layers with ones from the bitsandbytes library. '
             'Allows int8 training/inference, etc.'
)
parser.add_argument(
"--siglip",
default=False,
action="store_true",
        help='Use SigLIP (sigmoid) loss.'
)
parser.add_argument(
"--flashloss",
default=False,
action="store_true",
help='Use flash loss.'
)
parser.add_argument(
"--ringloss",
default=False,
action="store_true",
help='Use ring loss.'
)
parser.add_argument(
"--infloss",
default=False,
action="store_true",
help='Use ring flash loss.'
)
parser.add_argument(
"--discoloss",
default=False,
action="store_true",
help='Use disc loss.'
)
try:
import deepspeed
parser = deepspeed.add_config_arguments(parser)
        parser.add_argument('--zero-stage', type=int, default=1, help='stage of ZeRO')
    except ImportError:
        print("DeepSpeed is not installed. Please `pip install deepspeed==0.8.1`.")
        exit(1)
args = parser.parse_args(args)
if args.deepspeed:
create_deepspeed_config(args)
# If some params are not passed, we use the default values based on model name.
default_params = get_default_params(args.model)
for name, val in default_params.items():
if getattr(args, name) is None:
setattr(args, name, val)
return args
def create_deepspeed_config(args):
_, _, world_size = world_info_from_env()
args.deepspeed_config = os.path.join(os.getcwd(), "scripts", "deepspeed_config.json")
# default optimizer
optim_settings = None
if args.optimizer.lower() == "adamw":
optim_settings = {
"type": "Adam",
"adam_w_mode": True,
"params": {
"bias_correction": True,
"betas": [
args.beta1,
args.beta2
],
"eps": args.eps,
}
}
# LAMB
elif args.optimizer.lower() == "lamb":
# https://arxiv.org/pdf/1904.00962.pdf
optim_settings = {
"type": "LAMB",
"params": {
"bias_correction": True,
"betas": [
args.beta1,
args.beta2
],
"eps": args.eps,
"max_coeff": 10.0, #0.3
"min_coeff": 0.01,
"eps_inside_sqrt": False,
}
}
    elif args.optimizer.lower() == "1bitlamb":
        # not supported: 1-bit LAMB is incompatible with ZeRO, so zero-stage must be 0
        # https://arxiv.org/abs/2104.06069
optim_settings = {
"type": "OneBitLamb",
"params": {
"bias_correction": True,
"betas": [
args.beta1,
args.beta2
],
"eps": args.eps,
"max_coeff": 10.0, #0.3
"min_coeff": 0.01,
"eps_inside_sqrt": False,
"freeze_step": args.warmup,
# "comm_backend_name": "nccl",
# "coeff_beta": 0.9,
# "factor_max": 4.0,
# "factor_min": 0.5,
# "factor_threshold": 0.1
}
}
with open(args.deepspeed_config, mode="w") as writer:
ds_config = {
"train_batch_size": args.batch_size * world_size * args.accum_freq,
"train_micro_batch_size_per_gpu": args.batch_size,
"gradient_accumulation_steps": args.accum_freq,
"gradient_accumulation_dtype": "fp32",
"steps_per_print": 1000000,
"zero_allow_untested_optimizer": True,
"fp16": {
"enabled": True if args.precision != "bf16" else False,
# "auto_cast": True,
"loss_scale": 0,
"initial_scale_power": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": args.precision == "bf16"
},
"amp": {
"enabled": False,
"opt_level": "O2"
},
"flops_profiler": {
"enabled": True,
"profile_step": -1,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
},
"activation_checkpointing": {
"partition_activations": args.grad_checkpointing,
"contiguous_memory_optimization": False,
"profile": True
},
# "wallclock_breakdown": True
}
if optim_settings is not None:
ds_config.update({'optimizer': optim_settings})
if args.grad_clip_norm is not None:
ds_config.update({'gradient_clipping': args.grad_clip_norm})
if args.zero_stage == 1:
ds_config.update(
{
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8,
}
}
)
elif args.zero_stage == 2:
ds_config.update(
{
"zero_optimization": {
"stage": 2,
"contiguous_gradients": ('vit-b' not in args.model.lower()), # should be False if model is small,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8,
"cpu_offload": False
}
}
)
        elif args.zero_stage == 3:
            # NOTE: stage3_* options belong inside the "zero_optimization" block
            ds_config.update(
                {
                    "zero_optimization": {
                        "stage": 3,
                        "contiguous_gradients": True,
                        "overlap_comm": True,
                        "reduce_scatter": True,
                        "reduce_bucket_size": 5e4,
                        "allgather_bucket_size": 5e4,
                        "cpu_offload": False,
                        "stage3_max_live_parameters": 1e5,
                        "stage3_max_reuse_distance": 1e5,
                    },
                }
            )
elif args.zero_stage > 3:
raise NotImplementedError()
writer.write(json.dumps(ds_config, indent=2))
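# Illustrative sketch of the emitted scripts/deepspeed_config.json (values assume
# batch_size=512, world_size=8, accum_freq=8, precision="amp", zero_stage=1;
# keys abbreviated, not exhaustive):
# {
#   "train_batch_size": 32768,
#   "train_micro_batch_size_per_gpu": 512,
#   "gradient_accumulation_steps": 8,
#   "fp16": {"enabled": true, ...},
#   "zero_optimization": {"stage": 1, "reduce_bucket_size": 5e8}
# }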
================================================
FILE: inf_clip/train/utils.py
================================================
import os
import time
import logging
import subprocess
import multiprocessing
import torch
import torch.distributed as dist
import fsspec
from tqdm import tqdm
from contextlib import suppress
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def setup_logging(log_file, level, include_host=False):
if include_host:
import socket
hostname = socket.gethostname()
formatter = logging.Formatter(
f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
else:
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')
logging.root.setLevel(level)
loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
for logger in loggers:
logger.setLevel(level)
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logging.root.addHandler(stream_handler)
if log_file:
file_handler = logging.FileHandler(filename=log_file)
file_handler.setFormatter(formatter)
logging.root.addHandler(file_handler)
def remote_sync_s3(local_dir, remote_dir):
# skip epoch_latest which can change during sync.
result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if result.returncode != 0:
logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
return False
logging.info(f"Successfully synced with S3 bucket")
return True
def remote_sync_fsspec(local_dir, remote_dir):
# FIXME currently this is slow and not recommended. Look into speeding up.
a = fsspec.get_mapper(local_dir)
b = fsspec.get_mapper(remote_dir)
for k in a:
# skip epoch_latest which can change during sync.
if 'epoch_latest.pt' in k:
continue
logging.info(f'Attempting to sync {k}')
if k in b and len(a[k]) == len(b[k]):
logging.debug(f'Skipping remote sync for {k}.')
continue
try:
logging.info(f'Successful sync for {k}.')
b[k] = a[k]
except Exception as e:
logging.info(f'Error during remote sync for {k}: {e}')
return False
return True
def remote_sync(local_dir, remote_dir, protocol):
logging.info('Starting remote sync.')
if protocol == 's3':
return remote_sync_s3(local_dir, remote_dir)
elif protocol == 'fsspec':
return remote_sync_fsspec(local_dir, remote_dir)
else:
logging.error('Remote protocol not known')
return False
def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
while True:
time.sleep(sync_every)
remote_sync(local_dir, remote_dir, protocol)
def start_sync_process(sync_every, local_dir, remote_dir, protocol):
p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))
return p
# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
of = fsspec.open(file_path, "wb")
with of as f:
        torch.save(pt_obj, f)
def pt_load(file_path, map_location=None):
if file_path.startswith('s3'):
logging.info('Loading remote checkpoint, which may take a bit.')
of = fsspec.open(file_path, "rb")
with of as f:
out = torch.load(f, map_location=map_location)
return out
def check_exists(file_path):
try:
with fsspec.open(file_path):
pass
except FileNotFoundError:
return False
return True
def get_autocast(precision):
    if precision == 'amp':
        # NOTE: plain 'amp' is also mapped to bfloat16 autocast here (not float16)
        return lambda: torch.amp.autocast("cuda", dtype=torch.bfloat16)
elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
# amp_bfloat16 is more stable than amp float16 for clip training
return lambda: torch.amp.autocast("cuda", dtype=torch.bfloat16)
else:
return suppress
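# Usage sketch (illustrative; `args` and `model` are assumptions from the caller):
#   autocast = get_autocast(args.precision)
#   with autocast():
#       image_features, text_features, logit_scale = model(images, texts)
# Any precision other than the amp variants returns `suppress`, a no-op context.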
##################################
# Distributed training utilities #
##################################
def is_global_master(args):
return args.rank == 0
def is_local_master(args):
return args.local_rank == 0
def is_master(args, local=False):
return is_local_master(args) if local else is_global_master(args)
def is_using_horovod():
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
return True
else:
return False
def is_using_distributed():
if 'WORLD_SIZE' in os.environ:
return int(os.environ['WORLD_SIZE']) > 1
if 'SLURM_NTASKS' in os.environ:
return int(os.environ['SLURM_NTASKS']) > 1
return False
def world_info_from_env():
local_rank = 0
for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
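# Example (hypothetical torchrun launch across 2 nodes x 8 GPUs): the rank-3
# process on node 0 sees LOCAL_RANK=3, RANK=3, WORLD_SIZE=16, so this function
# returns (3, 3, 16).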
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if args.horovod:
assert hvd is not None, "Horovod is not installed"
hvd.init()
args.local_rank = int(hvd.local_rank())
args.rank = hvd.rank()
args.world_size = hvd.size()
args.distributed = True
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
elif is_using_distributed():
if 'SLURM_PROCID' in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ['LOCAL_RANK'] = str(args.local_rank)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = 'cuda:%d' % args.local_rank
else:
device = 'cuda:0'
torch.cuda.set_device(device)
else:
device = 'cpu'
args.device = device
device = torch.device(device)
return device
def broadcast_object(args, obj, src=0):
# broadcast a pickle-able python object from rank-0 to all ranks
if args.horovod:
return hvd.broadcast_object(obj, root_rank=src)
else:
if args.rank == src:
objects = [obj]
else:
objects = [None]
dist.broadcast_object_list(objects, src=src)
return objects[0]
def all_gather_object(args, obj, dst=0):
# gather a pickle-able python object across all ranks
if args.horovod:
return hvd.allgather_object(obj)
else:
objects = [None for _ in range(args.world_size)]
dist.all_gather_object(objects, obj)
return objects
================================================
FILE: inf_clip/utils.py
================================================
from itertools import repeat
import collections.abc
import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
module_match (dict): Dictionary of full module names to freeze (all if empty)
name (str): Full module name (prefix)
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
is_match = True
if module_match:
is_match = name in module_match
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for child_name, child in module.named_children():
full_child_name = '.'.join([name, child_name]) if name else child_name
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
if new_child is not child:
res.add_module(child_name, new_child)
return res
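# Usage sketch (illustrative; module names hypothetical): freeze every BatchNorm
# in an image tower, or only the modules named in module_match:
#   model.visual = freeze_batch_norm_2d(model.visual)
#   model.visual = freeze_batch_norm_2d(model.visual, module_match={'layer4.0.bn1': True})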
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)
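# e.g. to_2tuple(224) -> (224, 224), while to_2tuple((224, 336)) passes through
# unchanged because the input is already iterable.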
# Replaces all linear layers with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
for name, module in model.named_children():
if len(list(module.children())) > 0:
replace_linear(module, linear_replacement, include_modules, copy_weights)
if isinstance(module, torch.nn.Linear) and name in include_modules:
old_module = model._modules[name]
model._modules[name] = linear_replacement(
module.in_features,
module.out_features,
module.bias is not None,
)
if copy_weights:
model._modules[name].weight.data.copy_(old_module.weight.data)
if model._modules[name].bias is not None:
model._modules[name].bias.data.copy_(old_module.bias)
return model
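# Usage sketch (assumes the bitsandbytes package is installed; its Linear8bitLt
# int8 layer matches the (in_features, out_features, bias) construction above):
#   import bitsandbytes as bnb
#   model = replace_linear(model, bnb.nn.Linear8bitLt)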
def convert_int8_model_to_inference_mode(model):
for m in model.modules():
if hasattr(m, 'prepare_for_eval'):
int8_original_dtype = m.weight.dtype
m.prepare_for_eval()
m.int8_original_dtype = int8_original_dtype
================================================
FILE: inf_clip/zero_shot_classifier.py
================================================
from functools import partial
from itertools import islice
from typing import Callable, List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
def batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl
"""
it = iter(iterable)
while True:
batch = list(islice(it, n))
if not batch:
break
yield batch
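# e.g. list(batched(range(5), 2)) -> [[0, 1], [2, 3], [4]]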
def build_zero_shot_classifier(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
num_classes_per_batch: Optional[int] = 10,
device: Union[str, torch.device] = 'cpu',
use_tqdm: bool = False,
):
""" Build zero-shot classifier weights by iterating over class names in batches
Args:
model: CLIP model instance
tokenizer: CLIP tokenizer instance
classnames: A sequence of class (label) names
templates: A sequence of callables or format() friendly strings to produce templates per class name
num_classes_per_batch: The number of classes to batch together in each forward, all if None
device: Device to use.
use_tqdm: Enable TQDM progress bar.
"""
assert isinstance(templates, Sequence) and len(templates) > 0
assert isinstance(classnames, Sequence) and len(classnames) > 0
use_format = isinstance(templates[0], str)
num_templates = len(templates)
num_classes = len(classnames)
if use_tqdm:
import tqdm
num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)
iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)
else:
iter_wrap = iter
def _process_batch(batch_classnames):
num_batch_classes = len(batch_classnames)
texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
texts = tokenizer(texts).to(device)
class_embeddings = model.encode_text(texts, normalize=True)
class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
class_embeddings = class_embeddings.T
return class_embeddings
with torch.no_grad():
if num_classes_per_batch:
batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]
zeroshot_weights = torch.cat(batched_embeds, dim=1)
else:
zeroshot_weights = _process_batch(classnames)
return zeroshot_weights
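# Usage sketch (illustrative; `model`, `tokenizer`, and `images` are assumptions):
# the returned tensor has shape [embed_dim, num_classes], so image-to-class logits
# are a single matrix product:
#   classifier = build_zero_shot_classifier(
#       model, tokenizer, IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES,
#       num_classes_per_batch=10, device='cuda')
#   image_features = F.normalize(model.encode_image(images), dim=-1)
#   logits = 100. * image_features @ classifier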
def build_zero_shot_classifier_legacy(
model,
tokenizer,
classnames: Sequence[str],
templates: Sequence[Union[Callable, str]],
device: Union[str, torch.device] = 'cpu',
use_tqdm: bool = False,
):
""" Build zero-shot classifier weights by iterating over class names 1 by 1
Args:
model: CLIP model instance
tokenizer: CLIP tokenizer instance
classnames: A sequence of class (label) names
templates: A sequence of callables or format() friendly strings to produce templates per class name
device: Device to use.
use_tqdm: Enable TQDM progress bar.
"""
assert isinstance(templates, Sequence) and len(templates) > 0
assert isinstance(classnames, Sequence) and len(classnames) > 0
if use_tqdm:
import tqdm
iter_wrap = tqdm.tqdm
else:
iter_wrap = iter
use_format = isinstance(templates[0], str)
with torch.no_grad():
zeroshot_weights = []
for classname in iter_wrap(classnames):
texts = [template.format(classname) if use_format else template(classname) for template in templates]
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
================================================
FILE: inf_clip/zero_shot_metadata.py
================================================
OPENAI_IMAGENET_TEMPLATES = (
lambda c: f'a bad photo of a {c}.',
lambda c: f'a photo of many {c}.',
lambda c: f'a sculpture of a {c}.',
lambda c: f'a photo of the hard to see {c}.',
lambda c: f'a low resolution photo of the {c}.',
lambda c: f'a rendering of a {c}.',
lambda c: f'graffiti of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a cropped photo of the {c}.',
lambda c: f'a tattoo of a {c}.',
lambda c: f'the embroidered {c}.',
lambda c: f'a photo of a hard to see {c}.',
lambda c: f'a bright photo of a {c}.',
lambda c: f'a photo of a clean {c}.',
lambda c: f'a photo of a dirty {c}.',
lambda c: f'a dark photo of the {c}.',
lambda c: f'a drawing of a {c}.',
lambda c: f'a photo of my {c}.',
lambda c: f'the plastic {c}.',
lambda c: f'a photo of the cool {c}.',
lambda c: f'a close-up photo of a {c}.',
lambda c: f'a black and white photo of the {c}.',
lambda c: f'a painting of the {c}.',
lambda c: f'a painting of a {c}.',
lambda c: f'a pixelated photo of the {c}.',
lambda c: f'a sculpture of the {c}.',
lambda c: f'a bright photo of the {c}.',
lambda c: f'a cropped photo of a {c}.',
lambda c: f'a plastic {c}.',
lambda c: f'a photo of the dirty {c}.',
lambda c: f'a jpeg corrupted photo of a {c}.',
lambda c: f'a blurry photo of the {c}.',
lambda c: f'a photo of the {c}.',
lambda c: f'a good photo of the {c}.',
lambda c: f'a rendering of the {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'a photo of one {c}.',
lambda c: f'a doodle of a {c}.',
lambda c: f'a close-up photo of the {c}.',
lambda c: f'a photo of a {c}.',
lambda c: f'the origami {c}.',
lambda c: f'the {c} in a video game.',
lambda c: f'a sketch of a {c}.',
lambda c: f'a doodle of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a low resolution photo of a {c}.',
lambda c: f'the toy {c}.',
lambda c: f'a rendition of the {c}.',
lambda c: f'a photo of the clean {c}.',
lambda c: f'a photo of a large {c}.',
lambda c: f'a rendition of a {c}.',
lambda c: f'a photo of a nice {c}.',
lambda c: f'a photo of a weird {c}.',
lambda c: f'a blurry photo of a {c}.',
lambda c: f'a cartoon {c}.',
lambda c: f'art of a {c}.',
lambda c: f'a sketch of the {c}.',
lambda c: f'a embroidered {c}.',
lambda c: f'a pixelated photo of a {c}.',
lambda c: f'itap of the {c}.',
lambda c: f'a jpeg corrupted photo of the {c}.',
lambda c: f'a good photo of a {c}.',
lambda c: f'a plushie {c}.',
lambda c: f'a photo of the nice {c}.',
lambda c: f'a photo of the small {c}.',
lambda c: f'a photo of the weird {c}.',
lambda c: f'the cartoon {c}.',
lambda c: f'art of the {c}.',
lambda c: f'a drawing of the {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a black and white photo of a {c}.',
lambda c: f'the plushie {c}.',
lambda c: f'a dark photo of a {c}.',
lambda c: f'itap of a {c}.',
lambda c: f'graffiti of the {c}.',
lambda c: f'a toy {c}.',
lambda c: f'itap of my {c}.',
lambda c: f'a photo of a cool {c}.',
lambda c: f'a photo of a small {c}.',
lambda c: f'a tattoo of the {c}.',
)
# a much smaller subset of above prompts
# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
SIMPLE_IMAGENET_TEMPLATES = (
lambda c: f'itap of a {c}.',
lambda c: f'a bad photo of the {c}.',
lambda c: f'a origami {c}.',
lambda c: f'a photo of the large {c}.',
lambda c: f'a {c} in a video game.',
lambda c: f'art of the {c}.',
lambda c: f'a photo of the small {c}.',
)
IMAGENET_CLASSNAMES = (
"tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray",
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco",
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper",
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander",
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog",
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin",
"box turtle", "banded gecko", "green iguana", "Carolina anole",
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard",
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile",
"American alligator", "triceratops", "worm snake", "ring-necked snake",
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake",
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra",
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake",
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider",
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider",
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl",
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet",
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck",
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby",
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch",
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab",
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab",
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron",
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot",
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher",
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion",
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel",
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle",
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound",
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound",
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound",
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier",
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier",
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier",
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier",
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer",
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier",
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier",
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever",
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla",
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel",
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard",
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie",
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog",
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff",
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky",
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog",
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon",
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle",
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf",
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox",
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat",
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger",
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose",
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle",
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper",
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper",
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly",
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly",
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit",
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse",
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison",
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)",
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat",
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan",
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque",
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin",
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey",
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda",
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish",
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown",
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle",
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo",
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel",
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel",
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)",
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini",
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet",
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra",
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest",
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe",
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton",
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran",
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw",
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking",
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker",
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard",
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot",
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed",
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer",
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table",
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig",
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar",
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder",
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute",
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed",
"freight car", "French horn", "frying pan", "fur coat", "garbage truck",
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola",
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine",
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer",
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet",
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar",
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep",
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat",
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library",
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion",
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag",
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask",
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone",
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile",
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor",
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa",
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail",
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina",
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart",
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush",
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench",
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case",
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube",
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball",
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag",
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho",
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug",
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill",
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel",
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator",
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser",
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal",
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard",
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store",
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap",
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door",
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock",
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater",
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight",
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf",
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa",
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge",
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe",
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball",
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof",
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store",
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod",
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard",
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling",
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball",
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink",
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle",
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing",
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website",
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu",
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette",
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli",
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber",
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange",
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate",
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito",
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef",
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player",
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom",
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"
)
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
[project]
name = "inf-cl"
version = "1.2"
authors = [
{name = "Zesen Cheng", email = "cyanlaser@stu.pku.edu.cn"},
{name = "Hang Zhang"},
{name = "Kehan Li"},
{name = "Xin Li"},
]
description = "A highly memory-efficient contrastive loss."
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
classifiers = [
'Development Status :: 4 - Beta',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
]
dependencies = [
'numpy',
'triton>=2.2.0',
]
[project.urls]
Homepage = "https://github.com/clownrat6/Inf-CLIP/inf_cl"
Issues = "https://github.com/clownrat6/Inf-CLIP/issues"
[tool.pdm.build]
excludes = ["./.git"]
package-dir = "."
includes = ["./inf_cl"]
================================================
FILE: requirements.txt
================================================
--extra-index-url https://download.pytorch.org/whl/cu118
# basic dependencies
torch==2.2.0
torchvision==0.17.0
numpy==1.24.4
timm
# data processing
webdataset
pandas
ftfy
regex
braceexpand
# Newer pillow releases fix this bug: "UserWarning: image file could not be identified because WEBP support not installed"
pillow==10.4.0 # Refer to this issue: https://github.com/ContinuumIO/anaconda-issues/issues/10737
# logging tools
tensorboard
tensorboardX
================================================
FILE: scripts/benchmarks_eval.sh
================================================
clip_benchmark eval \
--model LiT-B-16 \
--pretrained work_dirs/epoch_8.pt \
--dataset datasets/imagenet.txt \
--recall_k 1 5 10 \
--dataset_root datasets/clip-benchmark/wds_{dataset_cleaned} \
--output "benchmark_{dataset}_{pretrained}_{model}_{language}_{task}.json"
================================================
FILE: scripts/cc12m/clip_vit-b-32_bs32k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=32768
LOCAL_BATCH_SIZE=512
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=40
TRAIN_NUM_SAMPLES=10445970
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
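# With the defaults above (WORLD_SIZE=1, NPROC_PER_NODE=8, LOCAL_BATCH_SIZE=512):
# ACCUMULATION_STEPS = 32768/(1*8*512) = 8 and WARMUP_STEPS = 10445970/(2*32768) = 159,
# i.e. warmup spans roughly half an epoch of optimizer steps.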
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=clip_cc12m
RUN_NAME=vit-b-32_bs32k_e40
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model ViT-B-32 \
--train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/IMAGE/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--lr 5e-4 \
--beta1 0.9 \
--beta2 0.98 \
--eps 1.0e-8 \
--wd 0.5 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 5 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
================================================
FILE: scripts/cc12m/lit_vit-b-16_bs32k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=32768
LOCAL_BATCH_SIZE=1024
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=20
TRAIN_NUM_SAMPLES=10445970
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_cc12m
RUN_NAME=lit-b-16_bs32k_e20
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-B-16 \
--train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/cc12m/lit_vit-b-32_bs32k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=32768
LOCAL_BATCH_SIZE=1024
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=20
TRAIN_NUM_SAMPLES=10445970
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_cc12m
RUN_NAME=lit-b-32_bs32k_e20
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-B-32 \
--train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/cc3m/clip_r50_bs4k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=4096
LOCAL_BATCH_SIZE=256
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=40
TRAIN_NUM_SAMPLES=3018714
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=clip_cc3m
RUN_NAME=r50_bs4k_e40
DATA_DIR=/mnt/damovl/MEDIA
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model RN50 \
--train-data ${DATA_DIR}'/cc3m/{0000..0301}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--lr 5e-4 \
--beta1 0.9 \
--beta2 0.98 \
--eps 1.0e-8 \
--wd 0.5 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 5 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/cc3m/clip_vit-b-32_bs16k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=16384
LOCAL_BATCH_SIZE=256
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=40
TRAIN_NUM_SAMPLES=3018714
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=clip_cc3m
RUN_NAME=vit-b-32_bs16k_e40
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model ViT-B-32 \
--train-data ${DATA_DIR}'/cc3m/cc3m-train-{0000..0575}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--lr 5e-4 \
--beta1 0.9 \
--beta2 0.98 \
--eps 1.0e-8 \
--wd 0.5 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 5 \
    --logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/cc3m/lit_vit-b-32_bs16k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=16384
LOCAL_BATCH_SIZE=256
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=20
TRAIN_NUM_SAMPLES=3018714
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_cc3m
RUN_NAME=lit-b-32_bs16k_e20
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-B-32 \
--train-data ${DATA_DIR}'/cc3m/{0000..1044}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 32 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/imagenet_eval.sh
================================================
torchrun --nproc_per_node 1 \
    -m inf_clip.train.main \
--imagenet-val datasets/imagenet-1k/val \
--model ViT-B-16 \
--pretrained openai \
--workers 64 \
================================================
FILE: scripts/laion400m/clip_vit-b-32_bs256k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=262144
LOCAL_BATCH_SIZE=512
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=8
TRAIN_NUM_SAMPLES=280321756
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=clip_laion400m
RUN_NAME=vit-b-32_bs256k_e8
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model ViT-B-32 \
--train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--lr 5e-4 \
--beta1 0.9 \
--beta2 0.98 \
--eps 1.0e-8 \
--wd 0.5 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
================================================
FILE: scripts/laion400m/lit_vit-b-16_bs256k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Multiple conditions
if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then
WORLD_SIZE=$ARG_WORLD_SIZE
NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then
MASTER_ADDR=$ARG_MASTER_ADDR
MASTER_PORT=$ARG_MASTER_PORT
RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=262144
LOCAL_BATCH_SIZE=512
ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]
EPOCHS=8
TRAIN_NUM_SAMPLES=280321756
WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_laion400m
RUN_NAME=lit-b-16_bs256k_e8
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-B-16 \
--train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)'\
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/laion400m/lit_vit-b-32_bs256k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Fall back to the CLI arguments when the environment variables are unset
if [ -z "$WORLD_SIZE" ] || [ -z "$NPROC_PER_NODE" ]; then
    WORLD_SIZE=$ARG_WORLD_SIZE
    NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ] || [ -z "$RANK" ]; then
    MASTER_ADDR=$ARG_MASTER_ADDR
    MASTER_PORT=$ARG_MASTER_PORT
    RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=262144
LOCAL_BATCH_SIZE=512
ACCUMULATION_STEPS=$(( GLOBAL_BATCH_SIZE / (WORLD_SIZE * NPROC_PER_NODE * LOCAL_BATCH_SIZE) ))
EPOCHS=8
TRAIN_NUM_SAMPLES=280321756
WARMUP_STEPS=$(( TRAIN_NUM_SAMPLES / (2 * GLOBAL_BATCH_SIZE) ))
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_laion400m
RUN_NAME=lit-b-32_bs256k_e8
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-B-32 \
--train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)' \
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: scripts/laion400m/lit_vit-l-16_bs256k.sh
================================================
# Environment Variables
ARG_WORLD_SIZE=${1:-1}
ARG_NPROC_PER_NODE=${2:-8}
ARG_MASTER_ADDR="127.0.0.1"
ARG_MASTER_PORT=16666
ARG_RANK=${3:-0}
# Fall back to the CLI arguments when the environment variables are unset
if [ -z "$WORLD_SIZE" ] || [ -z "$NPROC_PER_NODE" ]; then
    WORLD_SIZE=$ARG_WORLD_SIZE
    NPROC_PER_NODE=$ARG_NPROC_PER_NODE
fi
if [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ] || [ -z "$RANK" ]; then
    MASTER_ADDR=$ARG_MASTER_ADDR
    MASTER_PORT=$ARG_MASTER_PORT
    RANK=$ARG_RANK
fi
echo "WORLD_SIZE: $WORLD_SIZE"
echo "NPROC_PER_NODE: $NPROC_PER_NODE"
# Training Arguments
GLOBAL_BATCH_SIZE=262144
LOCAL_BATCH_SIZE=512
ACCUMULATION_STEPS=$(( GLOBAL_BATCH_SIZE / (WORLD_SIZE * NPROC_PER_NODE * LOCAL_BATCH_SIZE) ))
EPOCHS=8
TRAIN_NUM_SAMPLES=280321756
WARMUP_STEPS=$(( TRAIN_NUM_SAMPLES / (2 * GLOBAL_BATCH_SIZE) ))
echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS"
# Log Arguments
export TRANSFORMERS_OFFLINE=1
export WANDB_PROJECT=lit_laion400m
RUN_NAME=lit-l-16_bs256k_e8
DATA_DIR=datasets
OUTP_DIR=work_dirs
torchrun --nnodes $WORLD_SIZE \
--nproc_per_node $NPROC_PER_NODE \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--node_rank $RANK \
-m inf_clip.train.main \
--model LiT-L-16 \
--train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \
--train-num-samples $TRAIN_NUM_SAMPLES \
--aug-cfg scale='(0.08, 1.0)' \
--dataset-type webdataset \
--imagenet-val ${DATA_DIR}/imagenet-1k/val \
--epochs $EPOCHS \
--warmup $WARMUP_STEPS \
--batch-size $LOCAL_BATCH_SIZE \
--accum-freq $ACCUMULATION_STEPS \
--optim adafactor \
--lr 1e-3 \
--beta1 0.9 \
--beta2 0.95 \
--eps 1.0e-8 \
--wd 1e-4 \
--grad-clip-norm 1.0 \
--workers 16 \
--precision amp \
--infloss \
--log-every-n-steps 1 \
--logs $OUTP_DIR/$WANDB_PROJECT \
--name $RUN_NAME \
--save-frequency 1 \
--zeroshot-frequency 1 \
--report-to tensorboard \
--resume latest \
================================================
FILE: tests/example.py
================================================
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3         # Number of attention heads
    seq_length_q = 32768  # Total query count across all ranks (global batch size)
    seq_length_k = 32768  # Total key count across all ranks
    d_model = 256         # Dimension of each head

    # Randomly initialize this rank's shard of the inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key
    l = l.requires_grad_()                            # Logit scale

    return q, k, l


if __name__ == "__main__":
    # Initialize the distributed environment (e.g., when launched via torchrun)
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Example: image-text contrastive learning, where q holds the image
    # features, k the text features, and l the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # Labels are the diagonal elements by default.
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)
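# A minimal launch sketch (assumes a single node with 8 GPUs; adjust
# --nproc_per_node to the number of available devices):
#   torchrun --nproc_per_node 8 tests/example.py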