Repository: DAMO-NLP-SG/Inf-CLIP Branch: main Commit: d9f2833b3753 Files: 152 Total size: 471.8 KB Directory structure: gitextract_u47kktxp/ ├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── inf_cl/ │ ├── __init__.py │ ├── flash.py │ └── ring.py ├── inf_clip/ │ ├── __init__.py │ ├── constants.py │ ├── factory.py │ ├── model_configs/ │ │ ├── EVA01-g-14-plus.json │ │ ├── EVA01-g-14.json │ │ ├── EVA02-B-16.json │ │ ├── EVA02-E-14-plus.json │ │ ├── EVA02-E-14.json │ │ ├── EVA02-L-14-336.json │ │ ├── EVA02-L-14.json │ │ ├── LiT-B-16.json │ │ ├── LiT-B-32.json │ │ ├── LiT-L-16.json │ │ ├── MobileCLIP-B.json │ │ ├── MobileCLIP-S1.json │ │ ├── MobileCLIP-S2.json │ │ ├── RN101-quickgelu.json │ │ ├── RN101.json │ │ ├── RN50-quickgelu.json │ │ ├── RN50.json │ │ ├── RN50x16.json │ │ ├── RN50x4.json │ │ ├── RN50x64.json │ │ ├── ViT-B-16-SigLIP-256.json │ │ ├── ViT-B-16-SigLIP-384.json │ │ ├── ViT-B-16-SigLIP-512.json │ │ ├── ViT-B-16-SigLIP-i18n-256.json │ │ ├── ViT-B-16-SigLIP.json │ │ ├── ViT-B-16-plus-240.json │ │ ├── ViT-B-16-plus.json │ │ ├── ViT-B-16-quickgelu.json │ │ ├── ViT-B-16.json │ │ ├── ViT-B-32-256.json │ │ ├── ViT-B-32-plus-256.json │ │ ├── ViT-B-32-quickgelu.json │ │ ├── ViT-B-32.json │ │ ├── ViT-H-14-378-quickgelu.json │ │ ├── ViT-H-14-CLIPA-336.json │ │ ├── ViT-H-14-CLIPA.json │ │ ├── ViT-H-14-quickgelu.json │ │ ├── ViT-H-14.json │ │ ├── ViT-H-16.json │ │ ├── ViT-L-14-280.json │ │ ├── ViT-L-14-336.json │ │ ├── ViT-L-14-CLIPA-336.json │ │ ├── ViT-L-14-CLIPA.json │ │ ├── ViT-L-14-quickgelu.json │ │ ├── ViT-L-14.json │ │ ├── ViT-L-16-320.json │ │ ├── ViT-L-16-SigLIP-256.json │ │ ├── ViT-L-16-SigLIP-384.json │ │ ├── ViT-L-16.json │ │ ├── ViT-M-16-alt.json │ │ ├── ViT-M-16.json │ │ ├── ViT-M-32-alt.json │ │ ├── ViT-M-32.json │ │ ├── ViT-S-16-alt.json │ │ ├── ViT-S-16.json │ │ ├── ViT-S-32-alt.json │ │ ├── ViT-S-32.json │ │ ├── ViT-SO400M-14-SigLIP-384.json │ │ ├── ViT-SO400M-14-SigLIP.json │ │ ├── ViT-bigG-14-CLIPA-336.json │ │ ├── ViT-bigG-14-CLIPA.json │ │ ├── ViT-bigG-14.json │ │ ├── ViT-e-14.json │ │ ├── ViT-g-14.json │ │ ├── ViTamin-B-LTT.json │ │ ├── ViTamin-B.json │ │ ├── ViTamin-L-256.json │ │ ├── ViTamin-L-336.json │ │ ├── ViTamin-L.json │ │ ├── ViTamin-L2-256.json │ │ ├── ViTamin-L2-336.json │ │ ├── ViTamin-L2.json │ │ ├── ViTamin-S-LTT.json │ │ ├── ViTamin-S.json │ │ ├── ViTamin-XL-256.json │ │ ├── ViTamin-XL-336.json │ │ ├── ViTamin-XL-384.json │ │ ├── coca_ViT-B-32.json │ │ ├── coca_ViT-L-14.json │ │ ├── coca_base.json │ │ ├── coca_roberta-ViT-B-32.json │ │ ├── convnext_base.json │ │ ├── convnext_base_w.json │ │ ├── convnext_base_w_320.json │ │ ├── convnext_large.json │ │ ├── convnext_large_d.json │ │ ├── convnext_large_d_320.json │ │ ├── convnext_small.json │ │ ├── convnext_tiny.json │ │ ├── convnext_xlarge.json │ │ ├── convnext_xxlarge.json │ │ ├── convnext_xxlarge_320.json │ │ ├── mt5-base-ViT-B-32.json │ │ ├── mt5-xl-ViT-H-14.json │ │ ├── nllb-clip-base-siglip.json │ │ ├── nllb-clip-base.json │ │ ├── nllb-clip-large-siglip.json │ │ ├── nllb-clip-large.json │ │ ├── roberta-ViT-B-32.json │ │ ├── swin_base_patch4_window7_224.json │ │ ├── vit_medium_patch16_gap_256.json │ │ ├── vit_relpos_medium_patch16_cls_224.json │ │ ├── xlm-roberta-base-ViT-B-32.json │ │ └── xlm-roberta-large-ViT-H-14.json │ ├── models/ │ │ ├── clip_arch.py │ │ ├── coca_arch.py │ │ ├── hf_configs.py │ │ ├── hf_model.py │ │ ├── lit_arch.py │ │ ├── loss.py │ │ ├── modified_resnet.py │ │ ├── pos_embed.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ └── transformer.py │ 
├── openai.py │ ├── pretrained.py │ ├── train/ │ │ ├── data.py │ │ ├── engine.py │ │ ├── main.py │ │ ├── optims.py │ │ ├── params.py │ │ └── utils.py │ ├── utils.py │ ├── zero_shot_classifier.py │ └── zero_shot_metadata.py ├── pyproject.toml ├── requirements.txt ├── scripts/ │ ├── benchmarks_eval.sh │ ├── cc12m/ │ │ ├── clip_vit-b-32_bs32k.sh │ │ ├── lit_vit-b-16_bs32k.sh │ │ └── lit_vit-b-32_bs32k.sh │ ├── cc3m/ │ │ ├── clip_r50_bs4k.sh │ │ ├── clip_vit-b-32_bs16k.sh │ │ └── lit_vit-b-32_bs16k.sh │ ├── imagenet_eval.sh │ └── laion400m/ │ ├── clip_vit-b-32_bs256k.sh │ ├── lit_vit-b-16_bs256k.sh │ ├── lit_vit-b-32_bs256k.sh │ └── lit_vit-l-16_bs256k.sh └── tests/ └── example.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ *.py linguist-language=python *.ipynb linguist-documentation ================================================ FILE: .gitignore ================================================ **/logs/ **/wandb/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ sync.sh gpu1sync.sh .idea *.pdf **/._* **/*DS_* **.jsonl src/sbatch src/misc .vscode src/debug core.* # Allow !src/evaluation/misc/results_dbs/* # log dirs /work_dirs*/ /datasets/ # oss logs /.ossutil* /ossutil* ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

# Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss

If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏
[![arXiv](https://img.shields.io/badge/Arxiv-2410.17243-AD1C18.svg?logo=arXiv)](https://arxiv.org/abs/2410.17243) [![hf_paper](https://img.shields.io/badge/🤗-Paper%20In%20HF-red.svg)](https://huggingface.co/papers/2410.17243) [![PyPI](https://img.shields.io/badge/PyPI-Inf--CL-9C276A.svg)](https://pypi.org/project/inf-cl)
[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/LICENSE) [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FDAMO-NLP-SG%2FInf-CLIP&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com) [![GitHub issues](https://img.shields.io/github/issues/DAMO-NLP-SG/Inf-CLIP?color=critical&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aopen+is%3Aissue) [![GitHub closed issues](https://img.shields.io/github/issues-closed/DAMO-NLP-SG/Inf-CLIP?color=success&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aissue+is%3Aclosed)
[![zhihu](https://img.shields.io/badge/-知乎-000000?logo=zhihu&logoColor=0084FF)](https://zhuanlan.zhihu.com/p/1681887214) [![Twitter](https://img.shields.io/badge/-Twitter-black?logo=twitter&logoColor=1D9BF0)](https://x.com/lixin4ever/status/1849669129613226457)
💡 Some other multimodal foundation model projects from our team may interest you ✨.

> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922)
> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VCD) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VCD.svg?style=social)](https://github.com/DAMO-NLP-SG/VCD) [![arXiv](https://img.shields.io/badge/Arxiv-2311.16922-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.16922)
> [**VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs**](https://github.com/DAMO-NLP-SG/VideoLLaMA2)
> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA2.svg?style=social)](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [![arXiv](https://img.shields.io/badge/Arxiv-2406.07476-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.07476)
> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787)
> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing
[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/CMM) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/CMM.svg?style=social)](https://github.com/DAMO-NLP-SG/CMM) [![arXiv](https://img.shields.io/badge/Arxiv-2410.12787-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.12787)

## 📰 News

* **[2024.10.18]** Released the training and evaluation code of Inf-CLIP.
## 🛠️ Requirements and Installation

Basic Dependencies:

* Python >= 3.8
* PyTorch >= 2.0.0
* CUDA Version >= 11.8

[Remote] Install Inf-CL:

```bash
# install from PyPI
pip install inf_cl -i https://pypi.org/simple
```

[Local] Install Inf-CL:

```bash
pip install -e .
```

Install required packages:

```bash
git clone https://github.com/DAMO-NLP-SG/Inf-CLIP
cd Inf-CLIP
pip install -r requirements.txt
```

## ⭐ Features

`inf_cl` is the Triton implementation of the Inf-CL loss:

* [x] [Ring-CL (inf_cl/ring.py#L238)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L238)
* [x] [Inf-CL (inf_cl/ring.py#L251)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_cl/ring.py#L251)

`inf_clip` is the CLIP training codebase with the Inf-CL loss and other training features:

- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L180)
- [x] [Gradient Cache (inf_clip/train/train.py#L292)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/train/train.py#L292)

## 🔑 Usage

A simple example of how to adopt the Inf-CL loss for contrastive learning. Try it with:

```
torchrun --nproc_per_node 2 tests/example.py
```

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3         # Number of attention heads
    seq_length_q = 32768  # Sequence length
    seq_length_k = 32768
    d_model = 256         # Dimension of each head (must be 16, 32, 64, 128, or 256)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_()  # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_()  # Key
    l = l.requires_grad_()                            # Logit scale

    return q, k, l


if __name__ == "__main__":
    # Assume that the distributed environment has been initialized
    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Example: image-text contrastive learning, where q is the global image
    # features, k is the text features, and l is the logit scale.
    q, k, l = create_cl_tensors(rank, world_size)

    # Labels are the diagonal elements by default.
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)
```
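For a standard bidirectional CLIP objective, the loss is applied in both directions, and the self-test at the bottom of `inf_cl/ring.py` checks the result against a vanilla all-gather + cross-entropy baseline. The sketch below condenses both patterns; it reuses `q`, `k`, `l` from `create_cl_tensors` above, and the helper names `bidirectional_inf_loss` and `reference_loss` are illustrative, not part of the package:

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

from inf_cl import cal_inf_loss


def bidirectional_inf_loss(q, k, l):
    # Symmetric CLIP objective: average the image-to-text and text-to-image
    # directions, as in the self-test of inf_cl/ring.py.
    loss_i2t = cal_inf_loss(q, k, scale=l.exp())
    loss_t2i = cal_inf_loss(k, q, scale=l.exp())
    return (loss_i2t + loss_t2i) * 0.5


def reference_loss(q, k, l, rank, world_size):
    # Vanilla baseline (illustrative helper): gather every rank's features,
    # materialize the full logit matrix, and apply cross-entropy in both
    # directions. Memory grows quadratically with the global batch size, so
    # this is only feasible as a small-scale sanity check.
    gathered_q = [torch.zeros_like(q) for _ in range(world_size)]
    gathered_k = [torch.zeros_like(k) for _ in range(world_size)]
    dist.all_gather(gathered_q, q)
    dist.all_gather(gathered_k, k)
    gathered_q[rank] = q  # re-insert the local shard so gradients can flow
    gathered_k[rank] = k
    all_q = torch.cat(gathered_q, dim=0)
    all_k = torch.cat(gathered_k, dim=0)
    logits = torch.einsum("md,nd->mn", l.exp() * all_q, all_k)
    labels = torch.arange(logits.shape[0], device=q.device)
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels))
```

Note that the self-test in `inf_cl/ring.py` additionally rescales the baseline loss by `world_size` before calling `backward()` so that its gradients match those of `cal_inf_loss`, which averages over the local batch.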
## 🚀 Main Results

### Memory Cost
\* denotes adopting the "data offload" strategy.

### Max Supported Batch Size

### Speed

### Batch Size Scaling

Training at a larger data scale requires a larger batch size.

## 🗝️ Training & Evaluation

### Quick Start

To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP and evaluate the trained model on the mainstream CLIP benchmarks.

1. Training Data Structure:

```bash
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│   │   ├── 00000.tar
│   │   ├── 00001.tar
│   │   ├── ...
│   │   └── 41407.tar
```

2. Command:

```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```

3. Evaluation Data Structure:

```bash
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of imagenet from https://huggingface.co/datasets/ILSVRC/imagenet-1k/tree/main/data
│   │   └── val/ # python datasets/reformat_imagenet.py
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   ├── clip-benchmark/ # bash datasets/benchmarks_download.sh
│   │   ├── wds_mscoco_captions
│   │   ├── wds_flickr8k
│   │   ├── wds_flickr30k
│   │   ├── wds_imagenet1k
│   │   ├── wds_imagenetv2
│   │   ├── wds_imagenet_sketch
│   │   ├── wds_imagenet-a
│   │   ├── wds_imagenet-r
│   │   ├── wds_imagenet-o
│   │   └── wds_objectnet
```

4. Command:

```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```

## 📑 Citation

If you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:

```bibtex
@article{damovl2024infcl,
  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},
  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},
  journal={arXiv preprint arXiv:2410.17243},
  year={2024},
  url={https://arxiv.org/abs/2410.17243}
}
```

## 👍 Acknowledgement

The codebase of Inf-CLIP is adapted from [**OpenCLIP**](https://github.com/mlfoundations/open_clip). We are also grateful to the following projects, from which Inf-CL arose:

* [**OpenAI CLIP**](https://openai.com/index/clip/), [**img2dataset**](https://github.com/rom1504/img2dataset), [**CLIP-Benchmark**](https://github.com/LAION-AI/CLIP_benchmark).
* [**FlashAttention**](https://github.com/Dao-AILab/flash-attention), [**RingAttention**](https://github.com/haoliuhl/ringattention), [**RingFlashAttention**](https://github.com/zhuzilin/ring-flash-attention).

## 🔒 License

This project is released under the Apache 2.0 license as found in the LICENSE file. The service is a research preview intended for **non-commercial use ONLY**, subject to the model License of CLIP, the Terms of Use of the data generated by OpenAI, and the terms of use of LAION. Please get in touch with us if you find any potential violations.
================================================ FILE: inf_cl/__init__.py ================================================ from .flash import cal_flash_loss from .ring import cal_ring_loss, cal_inf_loss ================================================ FILE: inf_cl/flash.py ================================================ import math import torch import torch.nn.functional as F import numpy as np import triton import triton.language as tl @triton.jit def _prob_fwd_kernel( Q, K, LSE, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): # start index of sequence length start_m = tl.program_id(0) # initialize offsets ndims = nheads * BLOCK_HEADDIM offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # Initialize pointers to Q, K, V q_ptrs = Q + ndims * offs_m[:, None] k_ptrs = K + ndims * offs_n[:, None] # initialize pointer to m and l lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # loop over k, v and update accumulator end_n = seqlen_k for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) for off_h in range(nheads): offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] # -- fetch q and k of a single head ---- q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) # -- compute qk ---- qk += tl.dot(q, tl.trans(k)) # Trying to combine the two masks seem to make the result wrong m_ij = tl.maximum(tl.max(qk, 1), m_i) p = tl.exp(qk - m_ij[:, None]) # Fix out of bound access p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0) # -- update statistics lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1) m_i = m_ij lse_i = m_i + tl.log(lse_i) # mask out the padded values lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0) tl.store(LSE + offs_m, lse_i) @triton.jit def _dq_prob_bwd_kernel( Q, K, dQ, LSE, dLSE, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" # start index of sequence length start_m = tl.program_id(0) # initialize offsets ndims = nheads * BLOCK_HEADDIM offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # Initialize pointers to Q, K, V q_ptrs = Q + ndims * offs_m[:, None] dq_ptrs = dQ + ndims * offs_m[:, None] k_ptrs = K + ndims * offs_n[:, None] # setting lse lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0) dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0) # loop over k, v and update accumulator end_n = seqlen_k for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) for off_h in range(nheads): offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] # -- fetch q and k of a single head ---- q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) # -- compute qk ---- qk += tl.dot(q, tl.trans(k)) qk_grad = tl.exp(qk - lse[:, None]) qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0) qk_grad = qk_grad * dlse[:, None] qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, 
is_pure=True, pack=1) for off_h in range(nheads): offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] # -- fetch q and k of a single head ---- q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0) # -- compute q grad ---- # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch # A solution for this problem # Refer to issue: https://github.com/triton-lang/triton/issues/4574 # if allow_tf32: k = tl.inline_asm_elementwise(ASM, "=r, r", [k], dtype=tl.float32, is_pure=True, pack=1) q_grad = tl.dot(qk_grad, k) # Another solution for this problem # Refer to https://github.com/triton-lang/triton/issues/376 # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False) # -- store dq ---- dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0) tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q) @triton.jit def _dk_prob_bwd_kernel( Q, K, dK, LSE, dLSE, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): ASM: tl.constexpr = "cvt.rna.tf32.f32 $0, $1;" # start index of sequence length start_n = tl.program_id(0) # initialize offsets ndims = nheads * BLOCK_HEADDIM offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N offs_d = tl.arange(0, BLOCK_HEADDIM) # Initialize pointers to Q, K, V q_ptrs = Q + ndims * offs_m[:, None] k_ptrs = K + ndims * offs_n[:, None] dk_ptrs = dK + ndims * offs_n[:, None] # loop over q and update accumulator end_m = seqlen_q for start_m in range(0, end_m, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) # setting lse lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0) dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) for off_h in range(nheads): offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] # -- fetch q and k of a single head ---- q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0) k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) # -- compute qk ---- qk += tl.dot(q, tl.trans(k)) qk_grad = tl.exp(qk - lse[:, None]) qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0) qk_grad = qk_grad * dlse[:, None] qk_grad = tl.inline_asm_elementwise(ASM, "=r, r", [qk_grad], dtype=tl.float32, is_pure=True, pack=1) for off_h in range(nheads): offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :] # -- fetch q and k of a single head ---- q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0) k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) # -- compute k grad ---- q = tl.inline_asm_elementwise(ASM, "=r, r", [q], dtype=tl.float32, is_pure=True, pack=1) k_grad = tl.dot(tl.trans(qk_grad), q) # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32)) # -- store dk ---- dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0) tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k) def _flash_prob_forward(q, k): # shape constraints seqlen_q, nheads, d = q.shape seqlen_k, _, _ = k.shape assert k.shape == (seqlen_k, nheads, d) # assert d <= 128, "FlashAttention only support head dimensions up to 128" assert q.dtype == k.dtype, "All tensors must have the same type" # assert q.dtype in [torch.float16, torch.bfloat16], "Only 
support fp16 and bf16" assert q.is_cuda and k.is_cuda seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 lse = torch.empty((seqlen_q_rounded), device=q.device, dtype=torch.float32) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK_M = 64 BLOCK_N = 64 num_warps = 8 num_stages = 1 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1) _prob_fwd_kernel[grid]( q, k, lse, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) lse = lse[:seqlen_q] return lse def _flash_prob_backward(q, k, lse, dlse): # shape constraints seqlen_q, nheads, d = q.shape seqlen_k, _, _ = k.shape assert k.shape == (seqlen_k, nheads, d) # assert d <= 128, "FlashAttention only support head dimensions up to 128" assert q.dtype == k.dtype, "All tensors must have the same type" # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.zeros_like(k, dtype=torch.float32) q = q.contiguous() k = k.contiguous() dlse = dlse.contiguous() BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK_M = 64 BLOCK_N = 64 num_warps = 8 num_stages = 1 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), 1) _dq_prob_bwd_kernel[grid]( q, k, dq, lse, dlse, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) BLOCK_N = BLOCK_M BLOCK_M = BLOCK_N grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]), 1) _dk_prob_bwd_kernel[grid]( q, k, dk, lse, dlse, nheads, seqlen_q, seqlen_k, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) dq = dq[:seqlen_q] dk = dk[:seqlen_k] return dq, dk class FlashProb(torch.autograd.Function): @staticmethod def forward(ctx, q, k): lse = _flash_prob_forward(q, k) ctx.save_for_backward(q, k, lse) return lse @staticmethod def backward(ctx, dlse): q, k, lse = ctx.saved_tensors dq, dk = _flash_prob_backward(q, k, lse, dlse) return dq, dk def _cal_flash_loss(q, k, labels, head_dim=256): bq = q.shape[0] bk = k.shape[0] # NOTE: logits forward or backward should keep fp32 for better precision q = q.view(bq, -1, head_dim).float() k = k.view(bk, -1, head_dim).float() lse = FlashProb.apply(q, k) numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...]) loss = -numerator + lse return loss def cal_flash_loss(q, k, labels=None, scale=None, head_dim=256): if labels is None: labels = torch.arange(q.shape[0], device=q.device) if scale is None: scale = 1.0 return _cal_flash_loss(scale * q, k, labels, head_dim) if __name__ == '__main__': import time # Parameters num_heads = 3 # Number of attention heads seq_length_q = 32768 # Sequence length seq_length_k = 32768 d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128) # Randomly initialize inputs q = torch.rand((seq_length_q, num_heads * d_model), dtype=torch.float32, device="cuda") # Query k = torch.rand((seq_length_k, num_heads * d_model), dtype=torch.float32, device="cuda") # Key l = torch.ones([], device="cuda") * np.log(1 / 0.02); l.requires_grad = True q = F.normalize(q, p=2, dim=-1); q.requires_grad = True k = F.normalize(k, p=2, dim=-1); k.requires_grad = True q1 = q.clone().detach().requires_grad_(True) k1 = k.clone().detach().requires_grad_(True) l1 = l.clone().detach().requires_grad_(True) labels = torch.arange(seq_length_q).cuda() for i in range(1000): # A. 
torch gradient start = time.time() qk = torch.einsum("md,nd->mn", l.exp() * q, k) loss = F.cross_entropy(qk, labels, reduction="mean") loss.backward() end = time.time() # B. triton gradient start1 = time.time() loss1 = cal_flash_loss(q1, k1, labels, l1.exp()) loss1 = loss1.mean() loss1.backward() end1 = time.time() print("========= Difference =========") print(end - start, end1 - start1, l.grad, l1.grad) print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad))) q.grad = None; k.grad = None; l.grad = None q1.grad = None; k1.grad = None; l1.grad = None ================================================ FILE: inf_cl/ring.py ================================================ import os import math import random import torch import torch.distributed as dist import torch.distributed.nn as dist_nn import torch.nn.functional as F import numpy as np import triton import triton.language as tl from .flash import _flash_prob_forward, _flash_prob_backward, _cal_flash_loss class RingComm: def __init__(self, process_group: dist.ProcessGroup): self._process_group = process_group self._ops = [] self.rank = dist.get_rank(self._process_group) self.world_size = dist.get_world_size(self._process_group) self._reqs = None self.send_rank = (self.rank + 1) % self.world_size self.recv_rank = (self.rank - 1) % self.world_size # print(f'rank: {self.rank}, send_rank: {self.send_rank}, recv_rank: {self.recv_rank}') if process_group is not None: self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) def send_recv(self, to_send, recv_tensor = None): if recv_tensor is None: res = torch.empty_like(to_send) else: res = recv_tensor send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) self._ops.append(send_op) self._ops.append(recv_op) return res def commit(self): if self._reqs is not None: raise RuntimeError("commit called twice") self._reqs = dist.batch_isend_irecv(self._ops) def wait(self): if self._reqs is None: raise RuntimeError("wait called before commit") for req in self._reqs: req.wait() self._reqs = None self._ops = [] class GradientGather(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return x @staticmethod def backward(ctx, dx): dist.all_reduce(dx) return dx class RingProb(torch.autograd.Function): @staticmethod def forward(ctx, q, k, group): rank = dist.get_rank() k = k.contiguous() comm = RingComm(group) colle = [q, k] lse = None next_k = None for step in range(comm.world_size): if step + 1 != comm.world_size: next_k: torch.Tensor = comm.send_recv(k) comm.commit() # vanilla lse qk = torch.einsum("mhd,nhd->mn", q, k) block_lse = torch.log(torch.exp(qk).sum(dim=-1)) if step == 0: lse = block_lse else: lse = lse - F.logsigmoid(lse - block_lse) if step + 1 != comm.world_size: comm.wait() k = next_k # this should be out_padded colle.append(lse) ctx.save_for_backward(*colle) ctx.group = group return lse @staticmethod def backward(ctx, dlse): rank = dist.get_rank() q, k, lse = ctx.saved_tensors k_comm = RingComm(ctx.group) d_k_comm = RingComm(ctx.group) dq, dk = None, None next_dk = None block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device) block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device) next_dk, next_k = None, None for step in range(k_comm.world_size): if step + 1 != k_comm.world_size: next_k = 
k_comm.send_recv(k) k_comm.commit() # vanilla gradient calculation qk = torch.einsum("mhd,nhd->mn", q, k) qk_grad = torch.exp(qk - lse[:, None]).float() qk_grad = qk_grad * dlse[:, None] block_dq_buffer = torch.einsum("mn,nhd->mhd", qk_grad, k.float()) block_dk_buffer = torch.einsum("nm,mhd->nhd", qk_grad.T, q.float()) if step == 0: dq = block_dq_buffer dk = block_dk_buffer else: dq += block_dq_buffer d_k_comm.wait() dk = block_dk_buffer + next_dk if step + 1 != k_comm.world_size: k_comm.wait() k = next_k next_dk = d_k_comm.send_recv(dk) d_k_comm.commit() d_k_comm.wait() return dq, next_dk, None class InfProb(torch.autograd.Function): @staticmethod def forward(ctx, q, k, group): rank = dist.get_rank() k = k.contiguous() comm = RingComm(group) colle = [q, k] lse = None next_k = None for step in range(comm.world_size): if step + 1 != comm.world_size: next_k: torch.Tensor = comm.send_recv(k) comm.commit() # flash lse block_lse = _flash_prob_forward(q, k) if step == 0: lse = block_lse else: lse = lse - F.logsigmoid(lse - block_lse) if step + 1 != comm.world_size: comm.wait() k = next_k # this should be out_padded colle.append(lse) ctx.save_for_backward(*colle) ctx.group = group return lse @staticmethod def backward(ctx, dlse): rank = dist.get_rank() q, k, lse = ctx.saved_tensors k_comm = RingComm(ctx.group) d_k_comm = RingComm(ctx.group) dq, dk = None, None next_dk = None block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device) block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device) next_dk, next_k = None, None for step in range(k_comm.world_size): if step + 1 != k_comm.world_size: next_k = k_comm.send_recv(k) k_comm.commit() # flash gradient calculation block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse) if step == 0: dq = block_dq_buffer dk = block_dk_buffer else: dq += block_dq_buffer d_k_comm.wait() dk = block_dk_buffer + next_dk if step + 1 != k_comm.world_size: k_comm.wait() k = next_k next_dk = d_k_comm.send_recv(dk) d_k_comm.commit() d_k_comm.wait() return dq, next_dk, None def set_seed(rank, seed=42): seed = rank + seed random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def _cal_ring_loss(q, k, labels, head_dim=256): bq = q.shape[0] bk = k.shape[0] q = q.view(bq, -1, head_dim).float() k = k.view(bk, -1, head_dim).float() lse = RingProb.apply(q, k, None) numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...]) loss = -numerator + lse return loss def _cal_inf_loss(q, k, labels, head_dim=256): bq = q.shape[0] bk = k.shape[0] q = q.view(bq, -1, head_dim).float() k = k.view(bk, -1, head_dim).float() lse = InfProb.apply(q, k, None) numerator = torch.einsum("mhd,mhd->m", q, k[labels, ...]) loss = -numerator + lse return loss def cal_ring_loss(q, k, labels=None, scale=None, head_dim=256): """The triton implementation of the ring-cl. Args: q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D]. k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D]. labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None. scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None. head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256. 
""" if labels is None: labels = torch.arange(q.shape[0]).to(q.device) if scale is None: scale = 1.0 else: scale = GradientGather.apply(scale) if torch.distributed.is_initialized(): return _cal_ring_loss(scale * q, k, labels, head_dim).mean() else: return _cal_flash_loss(scale * q, k, labels, head_dim).mean() def cal_inf_loss(q, k, labels=None, scale=None, head_dim=256): """The triton implementation of the inf-cl. Args: q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D]. k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D]. labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None. scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None. head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256. """ if labels is None: labels = torch.arange(q.shape[0]).to(q.device) if scale is None: scale = 1.0 else: scale = GradientGather.apply(scale) if torch.distributed.is_initialized(): return _cal_inf_loss(scale * q, k, labels, head_dim).mean() else: return _cal_flash_loss(scale * q, k, labels, head_dim).mean() if __name__ == "__main__": import time dist.init_process_group("nccl") rank = dist.get_rank() world_size = dist.get_world_size() torch.cuda.set_device(f'cuda:{os.environ["LOCAL_RANK"]}') # Parameters dtype = torch.float32 num_heads = 3 # Number of attention heads seq_length_q = 32768 # Sequence length seq_length_k = 32768 d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128) # Randomly initialize inputs q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda") k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda") l = torch.ones([], dtype=dtype, device="cuda") * np.log(1 / 0.07); l = l.requires_grad_() # Logit scale q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key q1 = q.clone().detach().requires_grad_() k1 = k.clone().detach().requires_grad_() l1 = l.clone().detach().requires_grad_() for i in range(1000): # A. local torch gradient start = time.time() # A.1. gather q, k gathered_q = [torch.zeros_like(q) for _ in range(world_size)] gathered_k = [torch.zeros_like(k) for _ in range(world_size)] dist.all_gather(gathered_q, q) dist.all_gather(gathered_k, k) gathered_q[rank] = q gathered_k[rank] = k all_q = torch.cat(gathered_q, dim=0) all_k = torch.cat(gathered_k, dim=0) # A.2. calculating qk logits qk = torch.einsum("md,nd->mn", l.exp() * all_q, all_k) kq = qk.T _labels = torch.arange(seq_length_q).to(q.device) # A.3. calculating loss loss_i2t = F.cross_entropy(qk, _labels, reduction="mean") loss_t2i = F.cross_entropy(kq, _labels, reduction="mean") # A.4. scaling loss to normal value scale_factor = (all_q.shape[0] / q.shape[0]) loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor loss.backward() show_loss = loss.detach().clone() dist.all_reduce(show_loss) show_loss = show_loss / (world_size * scale_factor) end = time.time() dist.barrier() # B. 
triton implementation start1 = time.time() # labels = torch.arange(seq_length_q // world_size).to(q.device) loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp()) loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp()) loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5 loss1.backward() end1 = time.time() dist.barrier() if rank == 0: print(rank, end - start, end1 - start1, loss, show_loss, loss1) print(l.grad, l1.grad, torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad))) q.grad = None; k.grad = None; l.grad = None q1.grad = None; k1.grad = None; l1.grad = None ================================================ FILE: inf_clip/__init__.py ================================================ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss from .factory import list_models, add_model_config, get_model_config, load_checkpoint from .openai import load_openai_model, list_openai_models from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained from .models.tokenizer import SimpleTokenizer, tokenize, decode from .models.transform import image_transform, AugmentationCfg from .models.coca_arch import CoCa from .models.clip_arch import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg from .models.lit_arch import LiT from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES ================================================ FILE: inf_clip/constants.py ================================================ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) IMAGENET_MEAN = (0.485, 0.456, 0.406) IMAGENET_STD = (0.229, 0.224, 0.225) INCEPTION_MEAN = (0.5, 0.5, 0.5) INCEPTION_STD = (0.5, 0.5, 0.5) ================================================ FILE: inf_clip/factory.py ================================================ import json import logging import os import re from copy import deepcopy from dataclasses import asdict from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .openai import load_openai_model from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ list_pretrained_tags_by_model, download_pretrained_from_hf, convert_state_dict from .models.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH from .models.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs from .models.clip_arch import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg from .models.coca_arch import CoCa from .models.lit_arch import LiT from .models.loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, FlashClipLoss, RingClipLoss, InfClipLoss, DiscoClipLoss HF_HUB_PREFIX = 'hf-hub:' 
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def _rescan_model_configs(): global _MODEL_CONFIGS config_ext = ('.json',) config_files = [] for config_path in _MODEL_CONFIG_PATHS: if config_path.is_file() and config_path.suffix in config_ext: config_files.append(config_path) elif config_path.is_dir(): for ext in config_ext: config_files.extend(config_path.glob(f'*{ext}')) for cf in config_files: with open(cf, 'r') as f: model_cfg = json.load(f) if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): _MODEL_CONFIGS[cf.stem] = model_cfg _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} _rescan_model_configs() # initial populate of model config registry def list_models(): """ enumerate available model architectures based on config files """ return list(_MODEL_CONFIGS.keys()) def add_model_config(path): """ add model config path or file and update registry """ if not isinstance(path, Path): path = Path(path) _MODEL_CONFIG_PATHS.append(path) _rescan_model_configs() def get_model_config(model_name): if model_name in _MODEL_CONFIGS: return deepcopy(_MODEL_CONFIGS[model_name]) else: return None def _get_hf_config(model_id, cache_dir=None): config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) return config def get_tokenizer( model_name: str = '', context_length: Optional[int] = None, **kwargs, ): if model_name.startswith(HF_HUB_PREFIX): model_name = model_name[len(HF_HUB_PREFIX):] try: config = _get_hf_config(model_name)['model_cfg'] except Exception: tokenizer = HFTokenizer( model_name, context_length=context_length or DEFAULT_CONTEXT_LENGTH, **kwargs, ) return tokenizer else: config = get_model_config(model_name) assert config is not None, f"No valid model config found for {model_name}." 
text_config = config.get('text_cfg', {}) if 'tokenizer_kwargs' in text_config: tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) else: tokenizer_kwargs = kwargs if context_length is None: context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) if 'hf_tokenizer_name' in text_config: tokenizer = HFTokenizer( text_config['hf_tokenizer_name'], context_length=context_length, **tokenizer_kwargs, ) else: tokenizer = SimpleTokenizer( context_length=context_length, **tokenizer_kwargs, ) return tokenizer def load_state_dict(checkpoint_path: str, map_location='cpu'): checkpoint = torch.load(checkpoint_path, map_location=map_location) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict'] elif isinstance(checkpoint, torch.jit.ScriptModule): state_dict = checkpoint.state_dict() for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) else: state_dict = checkpoint if next(iter(state_dict.items()))[0].startswith('module'): state_dict = {k[7:]: v for k, v in state_dict.items()} return state_dict def load_checkpoint( model: Union[CLIP, CustomTextCLIP], checkpoint_path: str, strict: bool = True, ): if Path(checkpoint_path).suffix in ('.npz', '.npy'): # Separate path loading numpy big_vision (SigLIP) weights from open_clip.pretrained import load_big_vision_weights load_big_vision_weights(model, checkpoint_path) return {} state_dict = load_state_dict(checkpoint_path) # Detect & convert 3rd party state_dicts -> open_clip state_dict = convert_state_dict(model, state_dict) # Detect old format and make compatible with new format if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): state_dict = convert_to_custom_text_state_dict(state_dict) # If loading a non-SigLIP model for SigLIP training. 
See https://github.com/mlfoundations/open_clip/issues/712 if 'logit_bias' not in state_dict and model.logit_bias is not None: state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) # Certain text transformers no longer expect position_ids after transformers==4.31 position_id_key = 'text.transformer.embeddings.position_ids' if position_id_key in state_dict and not hasattr(model, position_id_key): del state_dict[position_id_key] resize_pos_embed(state_dict, model) resize_text_pos_embed(state_dict, model) # Finally, load the massaged state_dict into model incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys def create_model( model_name: str, pretrained: Optional[str] = None, precision: str = 'fp32', device: Union[str, torch.device] = 'cpu', jit: bool = False, force_quick_gelu: bool = False, force_custom_text: bool = False, force_patch_dropout: Optional[float] = None, force_image_size: Optional[Union[int, Tuple[int, int]]] = None, force_preprocess_cfg: Optional[Dict[str, Any]] = None, pretrained_image: bool = False, pretrained_hf: bool = False, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, require_pretrained: bool = False, **model_kwargs, ): force_preprocess_cfg = force_preprocess_cfg or {} preprocess_cfg = asdict(PreprocessCfg()) has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) config = _get_hf_config(model_id, cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) model_cfg = config['model_cfg'] pretrained_hf = False # override, no need to load original HF text weights else: model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names checkpoint_path = None model_cfg = None if isinstance(device, str): device = torch.device(device) if pretrained and pretrained.lower() == 'openai': logging.info(f'Loading pretrained {model_name} from OpenAI.') model = load_openai_model( model_name, precision=precision, device=device, cache_dir=cache_dir, ) else: model_cfg = model_cfg or get_model_config(model_name) if model_cfg is not None: logging.info(f'Loaded {model_name} model config.') else: logging.error(f'Model config for {model_name} not found; available models {list_models()}.') raise RuntimeError(f'Model config for {model_name} not found.') if force_quick_gelu: # override for use of QuickGELU on non-OpenAI transformer models model_cfg["quick_gelu"] = True if force_patch_dropout is not None and force_patch_dropout != False: # override the default patch dropout value model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout if force_image_size is not None and force_image_size != False: # override model config's image size model_cfg["vision_cfg"]["image_size"] = force_image_size is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) if pretrained_image: if is_timm_model: # pretrained weight loading for timm models set via vision_cfg model_cfg['vision_cfg']['timm_model_pretrained'] = True else: assert False, 'pretrained image towers currently only supported for timm models' # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes cast_dtype = get_cast_dtype(precision) is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) if is_hf_model: # load pretrained weights for HF text model IFF no CLIP weights being loaded # NOTE: disable pretrained_hf 
def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None and force_patch_dropout != False:
            # override the default patch dropout value
            # NOTE: a value of 0 (== False) is treated as "no override" by this check
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None and force_image_size != False:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # load pretrained weights for HF text model IFF no CLIP weights being loaded
            # NOTE: the pretrained_hf argument is disabled here; the upstream behaviour
            # (commented out below) is replaced by keeping the config's own setting,
            # with .get guarding configs that omit the key (default matches CLIPTextCfg).
            # model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
            model_cfg['text_cfg']['hf_model_pretrained'] = model_cfg['text_cfg'].get('hf_model_pretrained', True)  # and not pretrained
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        model_arch = model_cfg.pop("arch", "CLIP")
        if custom_text:
            if "CLIP" in model_arch:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
            elif "LiT" in model_arch:
                model = LiT(**model_cfg, cast_dtype=cast_dtype)
            elif "CoCa" in model_arch or "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                model.to(device=device, dtype=dtype)
                from .transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        else:
            model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}).')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size

    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model
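# --- Editor's note: hedged usage sketch, not part of the original file. ------
# create_model resolves the config from inf_clip/model_configs/<name>.json, or from the
# hub when the name carries the 'hf-hub:' prefix; only fp16/bf16/pure_* precisions cast
# the weights here, while 'amp' leaves them in fp32 and relies on autocast in the
# training loop.
def _example_create_model():  # hypothetical helper, for illustration only
    model = create_model('ViT-B-32', pretrained=None, precision='fp32', device='cpu')
    return model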
def create_loss(args):
    if args.distill_model:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        return SigLipLoss(
            rank=args.rank,
            world_size=args.world_size,
        )
    elif args.flashloss:
        return FlashClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.ringloss:
        return RingClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.infloss:
        return InfClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif args.discoloss:
        return DiscoClipLoss(
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )
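# --- Editor's note: hedged usage sketch, not part of the original file. ------
# The branches above are checked in order, so the loss flags act as mutually exclusive
# switches; the Namespace below lists only the fields create_loss reads on the infloss path.
def _example_create_loss():  # hypothetical helper, for illustration only
    from argparse import Namespace
    args = Namespace(
        distill_model=None, model='ViT-B-32', siglip=False,
        flashloss=False, ringloss=False, infloss=True, discoloss=False,
        rank=0, world_size=1, horovod=False,
    )
    return create_loss(args)  # -> InfClipLoss(rank=0, world_size=1, use_horovod=False)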
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = False,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        **model_kwargs,
):
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        **model_kwargs,
    )

    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)

    preprocess_train = image_transform_v2(
        pp_cfg,
        is_train=True,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform_v2(
        pp_cfg,
        is_train=False,
    )

    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        **model_kwargs,
):
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        cache_dir=cache_dir,
        require_pretrained=True,
        **model_kwargs,
    )

    if not return_transform:
        return model

    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )

    return model, preprocess
================================================ FILE: inf_clip/model_configs/EVA01-g-14-plus.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "timm_model_name": "eva_giant_patch14_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA01-g-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "timm_model_name": "eva_giant_patch14_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA02-B-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "timm_model_name": "eva02_base_patch16_clip_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA02-E-14-plus.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "timm_model_name": "eva02_enormous_patch14_clip_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 32 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA02-E-14.json ================================================ {
"embed_dim": 1024, "vision_cfg": { "image_size": 224, "timm_model_name": "eva02_enormous_patch14_clip_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA02-L-14-336.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 336, "timm_model_name": "eva02_large_patch14_clip_336", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/EVA02-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "timm_model_name": "eva02_large_patch14_clip_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/LiT-B-16.json ================================================ { "arch": "LiT-B-16", "embed_dim": 768, "vision_cfg": { "image_size": 224, "timm_model_name": "vit_base_patch16_224", "timm_model_pretrained": true, "timm_pool": "token", "timm_proj": "linear" }, "text_cfg": { "hf_tokenizer_name": "bert-base-uncased", "hf_model_name": "bert-base-uncased", "hf_model_pretrained": true, "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/LiT-B-32.json ================================================ { "arch": "LiT-B-32", "embed_dim": 768, "vision_cfg": { "image_size": 224, "timm_model_name": "vit_base_patch32_224", "timm_model_pretrained": true, "timm_pool": "token", "timm_proj": "linear" }, "text_cfg": { "hf_tokenizer_name": "bert-base-uncased", "hf_model_name": "bert-base-uncased", "hf_model_pretrained": true, "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/LiT-L-16.json ================================================ { "arch": "LiT-L-16", "embed_dim": 1024, "vision_cfg": { "image_size": 224, "timm_model_name": "vit_large_patch16_224", "timm_model_pretrained": true, "timm_pool": "token", "timm_proj": "linear" }, "text_cfg": { "hf_tokenizer_name": "bert-large-uncased", "hf_model_name": "bert-large-uncased", "hf_model_pretrained": true, "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/MobileCLIP-B.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vit_base_mci_224", "timm_model_pretrained": false, "timm_pool": "token", "timm_proj": null, "timm_drop": 0.0, "timm_drop_path": 0.0, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "no_causal_mask": false }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/MobileCLIP-S1.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "fastvit_mci1", "timm_model_pretrained": false, "timm_pool": "avg", "timm_proj": 
null, "timm_drop": 0.0, "timm_drop_path": 0.0, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "no_causal_mask": true }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/MobileCLIP-S2.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "fastvit_mci2", "timm_model_pretrained": false, "timm_pool": "avg", "timm_proj": null, "timm_drop": 0.0, "timm_drop_path": 0.0, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "no_causal_mask": true }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/RN101-quickgelu.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 23, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN101.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 23, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN50-quickgelu.json ================================================ { "embed_dim": 1024, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 6, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN50.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": [ 3, 4, 6, 3 ], "width": 64, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN50x16.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 384, "layers": [ 6, 8, 18, 8 ], "width": 96, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN50x4.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 288, "layers": [ 4, 6, 10, 6 ], "width": 80, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/RN50x64.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 448, "layers": [ 3, 15, 36, 10 ], "width": 128, "patch_size": null }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-SigLIP-256.json ================================================ { "embed_dim": 768, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 256, "timm_model_name": 
"vit_base_patch16_siglip_256", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 768, "heads": 12, "layers": 12, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-SigLIP-384.json ================================================ { "embed_dim": 768, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 384, "timm_model_name": "vit_base_patch16_siglip_384", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 768, "heads": 12, "layers": 12, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-SigLIP-512.json ================================================ { "embed_dim": 768, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 512, "timm_model_name": "vit_base_patch16_siglip_512", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 768, "heads": 12, "layers": 12, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json ================================================ { "embed_dim": 768, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 256, "timm_model_name": "vit_base_patch16_siglip_256", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 250000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 768, "heads": 12, "layers": 12, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-SigLIP.json ================================================ { "embed_dim": 768, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 224, "timm_model_name": "vit_base_patch16_siglip_224", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 768, "heads": 12, "layers": 12, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-plus-240.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 240, "layers": 12, "width": 896, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-plus.json 
================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 224, "layers": 12, "width": 896, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-16-quickgelu.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-32-256.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 256, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-32-plus-256.json ================================================ { "embed_dim": 640, "vision_cfg": { "image_size": 256, "layers": 12, "width": 896, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-32-quickgelu.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-H-14-378-quickgelu.json ================================================ { "embed_dim": 1024, "quick_gelu": true, "vision_cfg": { "image_size": 378, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/ViT-H-14-CLIPA-336.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 336, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, "text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 1024, "heads": 16, "layers": 24, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-H-14-CLIPA.json ================================================ { "embed_dim": 1024, "vision_cfg": { 
"image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, "text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 1024, "heads": 16, "layers": 24, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-H-14-quickgelu.json ================================================ { "embed_dim": 1024, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/ViT-H-16.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/ViT-L-14-280.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 280, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-L-14-336.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 336, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-L-14-CLIPA-336.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 336, "layers": 24, "width": 1024, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, "text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 768, "heads": 12, "layers": 12, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-L-14-CLIPA.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, "text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 768, "heads": 12, "layers": 12, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-L-14-quickgelu.json ================================================ { "embed_dim": 768, "quick_gelu": true, "vision_cfg": { 
"image_size": 224, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-L-16-320.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 320, "layers": 24, "width": 1024, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-L-16-SigLIP-256.json ================================================ { "embed_dim": 1024, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 256, "timm_model_name": "vit_large_patch16_siglip_256", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 1024, "heads": 16, "layers": 24, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-L-16-SigLIP-384.json ================================================ { "embed_dim": 1024, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 384, "timm_model_name": "vit_large_patch16_siglip_384", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 1024, "heads": 16, "layers": 24, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-L-16.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-M-16-alt.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 16, "ls_init_value": 1e-4 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-M-16.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-M-32-alt.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 32 }, "text_cfg": { "context_length": 77, 
"vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-M-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 512, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-S-16-alt.json ================================================ { "embed_dim": 256, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 256, "heads": 4, "layers": 10 } } ================================================ FILE: inf_clip/model_configs/ViT-S-16.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 16 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-S-32-alt.json ================================================ { "embed_dim": 256, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 256, "heads": 4, "layers": 10 } } ================================================ FILE: inf_clip/model_configs/ViT-S-32.json ================================================ { "embed_dim": 384, "vision_cfg": { "image_size": 224, "layers": 12, "width": 384, "patch_size": 32 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json ================================================ { "embed_dim": 1152, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 384, "timm_model_name": "vit_so400m_patch14_siglip_384", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 64, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 1152, "heads": 16, "layers": 27, "mlp_ratio": 3.7362, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-SO400M-14-SigLIP.json ================================================ { "embed_dim": 1152, "init_logit_bias": -10, "custom_text": true, "vision_cfg": { "image_size": 224, "timm_model_name": "vit_so400m_patch14_siglip_224", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "context_length": 16, "vocab_size": 32000, "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", "tokenizer_kwargs": { "clean": "canonicalize" }, "width": 1152, "heads": 16, "layers": 27, "mlp_ratio": 3.7362, "no_causal_mask": true, "proj_bias": true, "pool_type": "last", "norm_kwargs":{ "eps": 1e-6 } } } ================================================ FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 336, "layers": 48, "width": 1664, "head_width": 104, "mlp_ratio": 4.9231, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, 
"text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 1280, "heads": 20, "layers": 32, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-bigG-14-CLIPA.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 224, "layers": 48, "width": 1664, "head_width": 104, "mlp_ratio": 4.9231, "patch_size": 14, "no_ln_pre": true, "pool_type": "avg", "final_ln_after_pool": true }, "text_cfg": { "context_length": 32, "vocab_size": 32000, "hf_tokenizer_name": "bert-base-uncased", "tokenizer_kwargs": { "strip_sep_token": true }, "width": 1280, "heads": 20, "layers": 32, "pool_type": "last", "no_causal_mask": true } } ================================================ FILE: inf_clip/model_configs/ViT-bigG-14.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 224, "layers": 48, "width": 1664, "head_width": 104, "mlp_ratio": 4.9231, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 32 } } ================================================ FILE: inf_clip/model_configs/ViT-e-14.json ================================================ { "embed_dim": 1280, "vision_cfg": { "image_size": 224, "layers": 56, "width": 1792, "head_width": 112, "mlp_ratio": 8.5715, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1280, "heads": 20, "layers": 36 } } ================================================ FILE: inf_clip/model_configs/ViT-g-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 40, "width": 1408, "head_width": 88, "mlp_ratio": 4.3637, "patch_size": 14 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/ViTamin-B-LTT.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "vitamin_base_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-B.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vitamin_base_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L-256.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "vitamin_large_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L-336.json 
================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "vitamin_large_336", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 336 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "vitamin_large_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L2-256.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "vitamin_large2_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L2-336.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "vitamin_large2_336", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 336 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-L2.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "vitamin_large2_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-S-LTT.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "vitamin_small_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-S.json ================================================ { "embed_dim": 384, "vision_cfg": { "timm_model_name": "vitamin_small_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 384, "heads": 6, "layers": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-XL-256.json ================================================ { "embed_dim": 1152, "vision_cfg": { "timm_model_name": "vitamin_xlarge_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", 
"timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1152, "heads": 16, "layers": 27 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-XL-336.json ================================================ { "embed_dim": 1152, "vision_cfg": { "timm_model_name": "vitamin_xlarge_336", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 336 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1152, "heads": 16, "layers": 27 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/ViTamin-XL-384.json ================================================ { "embed_dim": 1152, "vision_cfg": { "timm_model_name": "vitamin_xlarge_384", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1152, "heads": 16, "layers": 27 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/coca_ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32, "attentional_pool": true, "attn_pooler_heads": 8, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "embed_cls": true, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12, "attn_pooler_heads": 8 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/coca_ViT-L-14.json ================================================ { "embed_dim": 768, "vision_cfg": { "image_size": 224, "layers": 24, "width": 1024, "patch_size": 14, "attentional_pool": true, "attn_pooler_heads": 8, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "embed_cls": true, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12, "attn_pooler_heads": 12 }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/coca_base.json ================================================ { "embed_dim": 512, "multimodal_cfg": { "width": 768, "context_length": 76, "vocab_size": 64000, "mlp_ratio": 4, "layers": 12, "dim_head": 64, "heads": 12, "n_queries": 256, "attn_pooler_heads": 8 }, "vision_cfg": { "image_size": 288, "layers": 12, "width": 768, "patch_size": 18, "output_tokens": true }, "text_cfg": { "context_length": 76, "vocab_size": 64000, "layers": 12, "heads": 12, "width": 768, "embed_cls": true, "output_tokens": true }, "custom_text": true } ================================================ FILE: inf_clip/model_configs/coca_roberta-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32, "output_tokens": true }, "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", "hf_proj_type": "linear", "width": 768, "output_tokens": true }, "multimodal_cfg": { "context_length": 76, "width": 768, "heads": 8, "layers": 12 }, "custom_text": true } 
================================================ FILE: inf_clip/model_configs/convnext_base.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_base_w.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_base_w_320.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "convnext_base", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_large.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_large_d.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "mlp", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 16 } } ================================================ FILE: inf_clip/model_configs/convnext_large_d_320.json ================================================ { "embed_dim": 768, "vision_cfg": { "timm_model_name": "convnext_large", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "mlp", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 768, "heads": 12, "layers": 16 } } ================================================ FILE: inf_clip/model_configs/convnext_small.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "convnext_small", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_tiny.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_tiny", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 224 }, "text_cfg": { 
"context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/convnext_xlarge.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 20 } } ================================================ FILE: inf_clip/model_configs/convnext_xxlarge.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xxlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/convnext_xxlarge_320.json ================================================ { "embed_dim": 1024, "vision_cfg": { "timm_model_name": "convnext_xxlarge", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "timm_drop": 0.0, "timm_drop_path": 0.1, "image_size": 320 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 1024, "heads": 16, "layers": 24 } } ================================================ FILE: inf_clip/model_configs/mt5-base-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "google/mt5-base", "hf_tokenizer_name": "google/mt5-base", "hf_pooler_type": "mean_pooler" } } ================================================ FILE: inf_clip/model_configs/mt5-xl-ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "hf_model_name": "google/mt5-xl", "hf_tokenizer_name": "google/mt5-xl", "hf_pooler_type": "mean_pooler" } } ================================================ FILE: inf_clip/model_configs/nllb-clip-base-siglip.json ================================================ { "embed_dim": 768, "custom_text": true, "init_logit_bias": -10, "vision_cfg": { "image_size": 384, "timm_model_name": "vit_base_patch16_siglip_384", "timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-600M", "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/nllb-clip-base.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-600M", "hf_tokenizer_name": "facebook/nllb-200-distilled-600M", "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/nllb-clip-large-siglip.json ================================================ { "embed_dim": 1152, "custom_text": true, "init_logit_bias": -10, "vision_cfg": { "image_size": 384, "timm_model_name": "vit_so400m_patch14_siglip_384", 
"timm_model_pretrained": false, "timm_pool": "map", "timm_proj": "none" }, "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-1.3B", "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/nllb-clip-large.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "hf_model_name": "facebook/nllb-200-distilled-1.3B", "hf_tokenizer_name": "facebook/nllb-200-distilled-1.3B", "hf_proj_type": "linear", "hf_pooler_type": "cls_pooler" } } ================================================ FILE: inf_clip/model_configs/roberta-ViT-B-32.json ================================================ { "embed_dim": 512, "quick_gelu": true, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "roberta-base", "hf_tokenizer_name": "roberta-base", "hf_pooler_type": "mean_pooler" } } ================================================ FILE: inf_clip/model_configs/swin_base_patch4_window7_224.json ================================================ { "embed_dim": 640, "vision_cfg": { "timm_model_name": "swin_base_patch4_window7_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 640, "heads": 10, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/vit_medium_patch16_gap_256.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vit_medium_patch16_gap_256", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 256 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json ================================================ { "embed_dim": 512, "vision_cfg": { "timm_model_name": "vit_relpos_medium_patch16_cls_224", "timm_model_pretrained": false, "timm_pool": "", "timm_proj": "linear", "image_size": 224 }, "text_cfg": { "context_length": 77, "vocab_size": 49408, "width": 512, "heads": 8, "layers": 12 } } ================================================ FILE: inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json ================================================ { "embed_dim": 512, "vision_cfg": { "image_size": 224, "layers": 12, "width": 768, "patch_size": 32 }, "text_cfg": { "hf_model_name": "xlm-roberta-base", "hf_tokenizer_name": "xlm-roberta-base", "hf_pooler_type": "mean_pooler" } } ================================================ FILE: inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json ================================================ { "embed_dim": 1024, "vision_cfg": { "image_size": 224, "layers": 32, "width": 1280, "head_width": 80, "patch_size": 14 }, "text_cfg": { "hf_model_name": "xlm-roberta-large", "hf_tokenizer_name": "xlm-roberta-large", "hf_pooler_type": "mean_pooler" } } ================================================ FILE: inf_clip/models/clip_arch.py ================================================ """ CLIP Model Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
""" import copy import logging import math from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.utils.checkpoint import checkpoint from functools import partial from .hf_model import HFTextEncoder from .modified_resnet import ModifiedResNet from .timm_model import TimmModel from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ text_global_pool from ..utils import to_2tuple @dataclass class CLIPVisionCfg: layers: Union[Tuple[int, int, int, int], int] = 12 width: int = 768 head_width: int = 64 mlp_ratio: float = 4.0 patch_size: int = 16 image_size: Union[Tuple[int, int], int] = 224 ls_init_value: Optional[float] = None # layer scale initial value patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) attn_pooler_queries: int = 256 # n_queries for attentional pooler attn_pooler_heads: int = 8 # n heads for attentional_pooling no_ln_pre: bool = False # disable pre transformer LayerNorm pos_embed_type: str = 'learnable' final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'tok' output_tokens: bool = False act_kwargs: Optional[dict] = None norm_kwargs: Optional[dict] = None timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') timm_proj_bias: bool = False # enable bias final projection timm_drop: float = 0. 
# head dropout timm_drop_path: Optional[float] = None # backbone stochastic depth @dataclass class CLIPTextCfg: context_length: int = 77 vocab_size: int = 49408 hf_tokenizer_name: Optional[str] = None tokenizer_kwargs: Optional[dict] = None width: int = 512 heads: int = 8 layers: int = 12 mlp_ratio: float = 4.0 ls_init_value: Optional[float] = None # layer scale initial value embed_cls: bool = False pad_id: int = 0 no_causal_mask: bool = False # disable causal masking final_ln_after_pool: bool = False # apply final LayerNorm after pooling pool_type: str = 'argmax' proj_bias: bool = False output_tokens: bool = False act_kwargs: dict = None norm_kwargs: dict = None # HuggingFace specific text tower config hf_model_name: Optional[str] = None hf_model_pretrained: bool = True hf_proj_type: str = 'mlp' hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models def get_cast_dtype(precision: str): cast_dtype = None if precision == 'bf16': cast_dtype = torch.bfloat16 elif precision == 'fp16': cast_dtype = torch.float16 return cast_dtype def get_input_dtype(precision: str): input_dtype = None if precision in ('bf16', 'pure_bf16'): input_dtype = torch.bfloat16 elif precision in ('fp16', 'pure_fp16'): input_dtype = torch.float16 return input_dtype def _build_vision_tower( embed_dim: int, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None ): if isinstance(vision_cfg, dict): vision_cfg = CLIPVisionCfg(**vision_cfg) # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more # memory efficient in recent PyTorch releases (>= 1.10). # NOTE: timm models always use native GELU regardless of quick_gelu flag. act_layer = QuickGELU if quick_gelu else nn.GELU if vision_cfg.timm_model_name: visual = TimmModel( vision_cfg.timm_model_name, pretrained=vision_cfg.timm_model_pretrained, pool=vision_cfg.timm_pool, proj=vision_cfg.timm_proj, proj_bias=vision_cfg.timm_proj_bias, drop=vision_cfg.timm_drop, drop_path=vision_cfg.timm_drop_path, patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, embed_dim=embed_dim, image_size=vision_cfg.image_size, ) elif isinstance(vision_cfg.layers, (tuple, list)): vision_heads = vision_cfg.width * 32 // vision_cfg.head_width visual = ModifiedResNet( layers=vision_cfg.layers, output_dim=embed_dim, heads=vision_heads, image_size=vision_cfg.image_size, width=vision_cfg.width, ) else: vision_heads = vision_cfg.width // vision_cfg.head_width norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if vision_cfg.norm_kwargs: norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) if vision_cfg.act_kwargs is not None: act_layer = partial(act_layer, **vision_cfg.act_kwargs) visual = VisionTransformer( image_size=vision_cfg.image_size, patch_size=vision_cfg.patch_size, width=vision_cfg.width, layers=vision_cfg.layers, heads=vision_heads, mlp_ratio=vision_cfg.mlp_ratio, ls_init_value=vision_cfg.ls_init_value, patch_dropout=vision_cfg.patch_dropout, attentional_pool=vision_cfg.attentional_pool, attn_pooler_queries=vision_cfg.attn_pooler_queries, attn_pooler_heads=vision_cfg.attn_pooler_heads, pos_embed_type=vision_cfg.pos_embed_type, no_ln_pre=vision_cfg.no_ln_pre, final_ln_after_pool=vision_cfg.final_ln_after_pool, pool_type=vision_cfg.pool_type, output_tokens=vision_cfg.output_tokens, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return visual def _build_text_tower( embed_dim: int, text_cfg: CLIPTextCfg, quick_gelu: bool = False, 
cast_dtype: Optional[torch.dtype] = None, ): if isinstance(text_cfg, dict): text_cfg = CLIPTextCfg(**text_cfg) if text_cfg.hf_model_name: text = HFTextEncoder( text_cfg.hf_model_name, output_dim=embed_dim, proj_type=text_cfg.hf_proj_type, pooler_type=text_cfg.hf_pooler_type, pretrained=text_cfg.hf_model_pretrained, output_tokens=text_cfg.output_tokens, ) else: act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm if text_cfg.norm_kwargs: norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) if text_cfg.act_kwargs is not None: act_layer = partial(act_layer, **text_cfg.act_kwargs) text = TextTransformer( context_length=text_cfg.context_length, vocab_size=text_cfg.vocab_size, width=text_cfg.width, heads=text_cfg.heads, layers=text_cfg.layers, mlp_ratio=text_cfg.mlp_ratio, ls_init_value=text_cfg.ls_init_value, output_dim=embed_dim, embed_cls=text_cfg.embed_cls, no_causal_mask=text_cfg.no_causal_mask, pad_id=text_cfg.pad_id, pool_type=text_cfg.pool_type, proj_bias=text_cfg.proj_bias, output_tokens=text_cfg.output_tokens, act_layer=act_layer, norm_layer=norm_layer, ) return text class CLIP(nn.Module): output_dict: torch.jit.Final[bool] arch_type: torch.jit.Final[str] = 'clip' def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.transformer = text.transformer self.context_length = text.context_length self.vocab_size = text.vocab_size self.token_embedding = text.token_embedding self.positional_embedding = text.positional_embedding self.ln_final = text.ln_final self.text_projection = text.text_projection self.text_pool_type = text.pool_type self.register_buffer('attn_mask', text.attn_mask, persistent=False) self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.transformer.grad_checkpointing = enable def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): cast_dtype = self.transformer.get_cast_dtype() x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.to(cast_dtype) x = self.transformer(x, attn_mask=self.attn_mask) x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] x, _ = text_global_pool(x, text, self.text_pool_type) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): x = self.text_projection(x) else: x = x @ self.text_projection return F.normalize(x, dim=-1) if normalize else x def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = 
self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T return image_logits, text_logits def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, text_features, self.logit_scale.exp() class CustomTextCLIP(nn.Module): output_dict: torch.jit.Final[bool] arch_type: torch.jit.Final[str] = 'clip' def __init__( self, embed_dim: int, vision_cfg: CLIPVisionCfg, text_cfg: CLIPTextCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, cast_dtype: Optional[torch.dtype] = None, output_dict: bool = False, ): super().__init__() self.output_dict = output_dict self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) self.context_length = self.text.context_length self.vocab_size = self.text.vocab_size self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) else: self.logit_bias = None def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): self.text.lock(unlocked_layers, freeze_layer_norm) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) def encode_image(self, image, normalize: bool = False): features = self.visual(image) return F.normalize(features, dim=-1) if normalize else features def encode_text(self, text, normalize: bool = False): features = self.text(text) return F.normalize(features, dim=-1) if normalize else features def get_logits(self, image, text): image_features = self.encode_image(image, normalize=True) text_features = self.encode_text(text, normalize=True) image_logits = self.logit_scale.exp() * image_features @ text_features.T if self.logit_bias is not None: image_logits += self.logit_bias text_logits = image_logits.T return image_logits, text_logits def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if self.output_dict: out_dict = { "image_features": image_features, "text_features": text_features, "logit_scale": self.logit_scale.exp() } if self.logit_bias is not None: out_dict['logit_bias'] = self.logit_bias return out_dict if self.logit_bias is not None: return image_features, text_features, self.logit_scale.exp(), self.logit_bias return image_features, 
text_features, self.logit_scale.exp() def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): """Convert applicable model parameters to low-precision (bf16 or fp16)""" def _convert_weights(l): if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): l.weight.data = l.weight.data.to(dtype) if l.bias is not None: l.bias.data = l.bias.data.to(dtype) if isinstance(l, (nn.MultiheadAttention, Attention)): for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: tensor = getattr(l, attr) if tensor is not None: tensor.data = tensor.data.to(dtype) if isinstance(l, (CLIP, TextTransformer)): # convert text nn.Parameter projections attr = getattr(l, "text_projection", None) if attr is not None: attr.data = attr.data.to(dtype) if isinstance(l, VisionTransformer): # convert vision nn.Parameter projections attr = getattr(l, "proj", None) if attr is not None: attr.data = attr.data.to(dtype) model.apply(_convert_weights) convert_weights_to_fp16 = convert_weights_to_lp # backwards compat # used to maintain checkpoint compatibility def convert_to_custom_text_state_dict(state_dict: dict): if 'text_projection' in state_dict: # old format state_dict, move text tower -> .text new_state_dict = {} for k, v in state_dict.items(): if any(k.startswith(p) for p in ( 'text_projection', 'positional_embedding', 'token_embedding', 'transformer', 'ln_final', )): k = 'text.' + k new_state_dict[k] = v return new_state_dict return state_dict def build_model_from_openai_state_dict( state_dict: dict, quick_gelu=True, cast_dtype=torch.float16, ): vit = "visual.proj" in state_dict if vit: vision_width = state_dict["visual.conv1.weight"].shape[0] vision_layers = len( [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) image_size = vision_patch_size * grid_size else: counts: list = [ len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] vision_layers = tuple(counts) vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) vision_patch_size = None assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] image_size = output_width * 32 embed_dim = state_dict["text_projection"].shape[1] context_length = state_dict["positional_embedding"].shape[0] vocab_size = state_dict["token_embedding.weight"].shape[0] transformer_width = state_dict["ln_final.weight"].shape[0] transformer_heads = transformer_width // 64 transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) vision_cfg = CLIPVisionCfg( layers=vision_layers, width=vision_width, patch_size=vision_patch_size, image_size=image_size, ) text_cfg = CLIPTextCfg( context_length=context_length, vocab_size=vocab_size, width=transformer_width, heads=transformer_heads, layers=transformer_layers, ) model = CLIP( embed_dim, vision_cfg=vision_cfg, text_cfg=text_cfg, quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU cast_dtype=cast_dtype, ) for key in ["input_resolution", "context_length", "vocab_size"]: state_dict.pop(key, None) convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 model.load_state_dict(state_dict) return model.eval() def trace_model(model, 
batch_size=256, device=torch.device('cpu')): model.eval() image_size = model.visual.image_size example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) model = torch.jit.trace_module( model, inputs=dict( forward=(example_images, example_text), encode_text=(example_text,), encode_image=(example_images,) )) model.visual.image_size = image_size return model def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): # Rescale the grid of position embeddings when loading from state_dict old_pos_embed = state_dict.get('visual.positional_embedding', None) if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): return grid_size = to_2tuple(model.visual.grid_size) extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) new_seq_len = grid_size[0] * grid_size[1] + extra_tokens if new_seq_len == old_pos_embed.shape[0]: return if extra_tokens: pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] else: pos_emb_tok, pos_emb_img = None, old_pos_embed old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) pos_emb_img = F.interpolate( pos_emb_img, size=grid_size, mode=interpolation, antialias=antialias, align_corners=False, ) pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] if pos_emb_tok is not None: new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) else: new_pos_embed = pos_emb_img state_dict['visual.positional_embedding'] = new_pos_embed def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): old_pos_embed = state_dict.get('positional_embedding', None) if old_pos_embed is None: return # FIXME add support for text cls_token model_pos_embed = getattr(model, 'positional_embedding', None) if model_pos_embed is None: model_pos_embed = getattr(model.text, 'positional_embedding', None) old_num_pos = old_pos_embed.shape[0] old_width = old_pos_embed.shape[1] num_pos = model_pos_embed.shape[0] width = model_pos_embed.shape[1] assert old_width == width, 'text pos_embed width changed!' 
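# the interpolation below works channels-first: (num_pos, width) is reshaped and
# permuted to (1, width, num_pos) so F.interpolate's 1-d 'linear' mode resamples
# along the position axis per channel, then permuted back and squeezed to (num_pos, width)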
if old_num_pos == num_pos: return logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) old_pos_embed = F.interpolate( old_pos_embed, size=num_pos, mode=interpolation, antialias=antialias, align_corners=False, ) old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] new_pos_embed = old_pos_embed state_dict['positional_embedding'] = new_pos_embed def get_model_preprocess_cfg(model): module = getattr(model, 'visual', model) preprocess_cfg = getattr(module, 'preprocess_cfg', {}) if not preprocess_cfg: # use separate legacy attributes if preprocess_cfg dict not found size = getattr(module, 'image_size') if size is not None: preprocess_cfg['size'] = size mean = getattr(module, 'image_mean', None) if mean is not None: preprocess_cfg['mean'] = mean std = getattr(module, 'image_std', None) if std is not None: preprocess_cfg['std'] = std return preprocess_cfg def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): module = getattr(model, 'visual', model) module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict def get_model_tokenize_cfg(model): module = getattr(model, 'text', model) cfg = {} context_length = getattr(module, 'context_length', None) if context_length is not None: cfg['context_length'] = context_length vocab_size = getattr(module, 'vocab_size', None) if vocab_size is not None: cfg['vocab_size'] = vocab_size return cfg ================================================ FILE: inf_clip/models/coca_arch.py ================================================ from typing import Optional import torch from torch import nn from torch.nn import functional as F import numpy as np from dataclasses import dataclass from .transformer import ( LayerNormFp32, LayerNorm, QuickGELU, MultimodalTransformer, ) from .clip_arch import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower try: from transformers import ( BeamSearchScorer, LogitsProcessorList, TopPLogitsWarper, TopKLogitsWarper, RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, StopStringCriteria, EosTokenCriteria, StoppingCriteriaList ) GENERATION_TYPES = { "top_k": TopKLogitsWarper, "top_p": TopPLogitsWarper, "beam_search": "beam_search" } _has_transformers = True except ImportError as e: GENERATION_TYPES = { "top_k": None, "top_p": None, "beam_search": "beam_search" } _has_transformers = False @dataclass class MultimodalCfg(CLIPTextCfg): mlp_ratio: int = 4 dim_head: int = 64 heads: int = 8 n_queries: int = 256 attn_pooler_heads: int = 8 def _build_text_decoder_tower( embed_dim, multimodal_cfg, quick_gelu: bool = False, cast_dtype: Optional[torch.dtype] = None, ): multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg act_layer = QuickGELU if quick_gelu else nn.GELU norm_layer = ( LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm ) decoder = MultimodalTransformer( context_length=multimodal_cfg.context_length, width=multimodal_cfg.width, heads=multimodal_cfg.heads, layers=multimodal_cfg.layers, ls_init_value=multimodal_cfg.ls_init_value, output_dim=embed_dim, act_layer=act_layer, norm_layer=norm_layer, ) return decoder def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor: if not 
isinstance(token_id, torch.Tensor): if isinstance(token_id, int): token_id = [token_id] token_id = torch.tensor(token_id, device=device) return token_id class CoCa(nn.Module): arch_type: torch.jit.Final[str] = 'coca' def __init__( self, embed_dim, multimodal_cfg: MultimodalCfg, text_cfg: CLIPTextCfg, vision_cfg: CLIPVisionCfg, quick_gelu: bool = False, init_logit_scale: float = np.log(1 / 0.07), init_logit_bias: Optional[float] = None, cast_dtype: Optional[torch.dtype] = None, pad_id: int = 0, ): super().__init__() multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg self.text = _build_text_tower( embed_dim=embed_dim, text_cfg=text_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) vocab_size = ( text_cfg.vocab_size # for hf models if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None else text_cfg.vocab_size ) self.visual = _build_vision_tower( embed_dim=embed_dim, vision_cfg=vision_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) self.text_decoder = _build_text_decoder_tower( vocab_size, multimodal_cfg=multimodal_cfg, quick_gelu=quick_gelu, cast_dtype=cast_dtype, ) self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) if init_logit_bias is not None: self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) else: self.logit_bias = None self.pad_id = pad_id self.context_length = multimodal_cfg.context_length @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True): self.visual.set_grad_checkpointing(enable) self.text.set_grad_checkpointing(enable) self.text_decoder.set_grad_checkpointing(enable) def _encode_image(self, images, normalize: bool = True): image_latent, tokens_embs = self.visual(images) image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent return image_latent, tokens_embs def _encode_text(self, text, normalize: bool = True): text_latent, token_emb = self.text(text) text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent return text_latent, token_emb def encode_image(self, images, normalize: bool = True): image_latent, _ = self._encode_image(images, normalize=normalize) return image_latent def encode_text(self, text, normalize: bool = True): text_latent, _ = self._encode_text(text, normalize=normalize) return text_latent def forward( self, image, text: Optional[torch.Tensor] = None, image_latent: Optional[torch.Tensor] = None, image_embs: Optional[torch.Tensor] = None, output_labels: bool = True, ): if image_latent is None or image_embs is None: image_latent, image_embs = self._encode_image(image) if text is None: return {"image_features": image_latent, "image_embs": image_embs} text_latent, token_embs = self._encode_text(text) # FIXME this isn't an ideal solution, would like to improve -RW labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None if output_labels: # align text_embs and thus logits with labels for teacher-forcing caption loss token_embs = token_embs[:, :-1] logits = self.text_decoder(image_embs, token_embs) out_dict = { "image_features": image_latent, "text_features": text_latent, "logits": logits, "logit_scale": self.logit_scale.exp() } if labels is not None: out_dict["labels"] = labels if self.logit_bias is not None: out_dict["logit_bias"] = self.logit_bias return out_dict def generate( self, image, text=None, seq_len=30, 
max_seq_len=77, temperature=1., generation_type="beam_search", top_p=0.1, # keep tokens in the 1 - top_p quantile top_k=1, # keeps the top_k most probable tokens pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, repetition_penalty=1.0, fixed_output_length=False # if True output.shape == (batch_size, seq_len) ): # taking many ideas and components from HuggingFace GenerationMixin # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" device = image.device with torch.no_grad(): sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device) eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device) pad_token_id = self.pad_id if pad_token_id is None else pad_token_id logit_processor = LogitsProcessorList( [ MinLengthLogitsProcessor(min_seq_len, eos_token_id), RepetitionPenaltyLogitsProcessor(repetition_penalty), ] ) if stopping_criteria is None: stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] stopping_criteria = StoppingCriteriaList(stopping_criteria) if generation_type == "beam_search": output = self._generate_beamsearch( image_inputs=image, pad_token_id=pad_token_id, eos_token_id=eos_token_id, sot_token_id=sot_token_id, num_beams=num_beams, num_beam_groups=num_beam_groups, min_seq_len=min_seq_len, stopping_criteria=stopping_criteria, logit_processor=logit_processor, ) if fixed_output_length and output.shape[1] < seq_len: pad_len = seq_len - output.shape[1] return torch.cat(( output, torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id ), dim=1 ) return output elif generation_type == "top_p": logit_warper = GENERATION_TYPES[generation_type](top_p) elif generation_type == "top_k": logit_warper = GENERATION_TYPES[generation_type](top_k) else: raise ValueError( f"generation_type has to be one of " f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
) image_latent, image_embs = self._encode_image(image) if text is None: text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id was_training = self.training num_dims = len(text.shape) if num_dims == 1: text = text[None, :] self.eval() out = text while True: x = out[:, -max_seq_len:] cur_len = x.shape[1] logits = self( image, x, image_latent=image_latent, image_embs=image_embs, output_labels=False, )["logits"][:, -1] mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id if mask.all(): if not fixed_output_length: break else: logits = logits[~mask, :] filtered_logits = logit_processor(x[~mask, :], logits) filtered_logits = logit_warper(x[~mask, :], filtered_logits) probs = F.softmax(filtered_logits / temperature, dim=-1) if (cur_len + 1 == seq_len): sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id else: sample[~mask, :] = torch.multinomial(probs, 1) out = torch.cat((out, sample), dim=-1) cur_len += 1 if all(stopping_criteria(out, None)): break if num_dims == 1: out = out.squeeze(0) self.train(was_training) return out def _generate_beamsearch( self, image_inputs, pad_token_id=None, eos_token_id=None, sot_token_id=None, num_beams=6, num_beam_groups=3, min_seq_len=5, stopping_criteria=None, logit_processor=None, logit_warper=None, ): device = image_inputs.device batch_size = image_inputs.shape[0] image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) image_latent, image_embs = self._encode_image(image_inputs) input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) input_ids = input_ids * sot_token_id beam_scorer = BeamSearchScorer( batch_size=batch_size, num_beams=num_beams, device=device, num_beam_groups=num_beam_groups, ) # instantiate logits processors logits_processor = ( LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) if logit_processor is None else logit_processor ) num_beams = beam_scorer.num_beams num_beam_groups = beam_scorer.num_beam_groups num_sub_beams = num_beams // num_beam_groups batch_size = len(beam_scorer._beam_hyps) // num_beam_groups batch_beam_size, cur_len = input_ids.shape beam_indices = None if num_beams * batch_size != batch_beam_size: raise ValueError( f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." ) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in # the same group don't produce the same tokens every time.
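# e.g. with the defaults num_beams=6, num_beam_groups=3 -> num_sub_beams=2, the
# slice below zeroes columns 0, 2 and 4 (the first beam of each group) and leaves
# the rest at -1e9, so only one beam per group is live in the first topk step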
beam_scores[:, ::num_sub_beams] = 0 beam_scores = beam_scores.view((batch_size * num_beams,)) while True: # predicted tokens in cur_len step current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) # indices which will form the beams in the next time step reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) # do one decoder step on all beams of all sentences in batch model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) outputs = self( model_inputs['images'], model_inputs['text'], image_latent=image_latent, image_embs=image_embs, output_labels=False, ) for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams group_end_idx = min(group_start_idx + num_sub_beams, num_beams) group_size = group_end_idx - group_start_idx # indices of beams of current group among all sentences in batch batch_group_indices = [] for batch_idx in range(batch_size): batch_group_indices.extend( [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] ) group_input_ids = input_ids[batch_group_indices] # select outputs of beams of current group only next_token_logits = outputs['logits'][batch_group_indices, -1, :] vocab_size = next_token_logits.shape[-1] next_token_scores_processed = logits_processor( group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) # reshape for beam search next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) next_token_scores, next_tokens = torch.topk( next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True ) next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") next_tokens = next_tokens % vocab_size # stateless process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None beam_outputs = beam_scorer.process( group_input_ids, next_token_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, beam_indices=process_beam_indices, group_index=beam_group_idx, ) beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] beam_next_tokens = beam_outputs["next_beam_tokens"] beam_idx = beam_outputs["next_beam_indices"] input_ids[batch_group_indices] = group_input_ids[beam_idx] group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) current_tokens[batch_group_indices] = group_input_ids[:, -1] # (beam_idx // group_size) -> batch_idx # (beam_idx % group_size) -> offset of idx inside the group reordering_indices[batch_group_indices] = ( num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) ) input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) # increase cur_len cur_len = cur_len + 1 if beam_scorer.is_done or all(stopping_criteria(input_ids, None)): break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None sequence_outputs = beam_scorer.finalize( input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id, max_length=stopping_criteria.max_length, beam_indices=final_beam_indices, ) return sequence_outputs['sequences'] def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): if past: input_ids =
input_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) else: position_ids = None return { "text": input_ids, "images": image_inputs, "past_key_values": past, "position_ids": position_ids, "attention_mask": attention_mask, } ================================================ FILE: inf_clip/models/hf_configs.py ================================================ # HF architecture dict: arch_dict = { # https://huggingface.co/docs/transformers/model_doc/roberta#roberta "roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig "xlm-roberta": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", "layer_attr": "layer", "token_embeddings_attr": "embeddings" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 "mt5": { "config_names": { # unlimited seqlen # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 "context_length": "", "vocab_size": "vocab_size", "width": "d_model", "heads": "num_heads", "layers": "num_layers", "layer_attr": "block", "token_embeddings_attr": "embed_tokens" }, "pooler": "mean_pooler", }, # https://huggingface.co/docs/transformers/model_doc/bert "bert": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "hidden_size", "heads": "num_attention_heads", "layers": "num_hidden_layers", }, "pooler": "cls_pooler", }, # https://huggingface.co/docs/transformers/model_doc/m2m_100 "m2m_100": { "config_names": { "context_length": "max_position_embeddings", "vocab_size": "vocab_size", "width": "d_model", "heads": "encoder_attention_heads", "layers": "encoder_layers", }, "pooler": "cls_pooler", }, } ================================================ FILE: inf_clip/models/hf_model.py ================================================ """ huggingface model adapter Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. """ import re from contextlib import nullcontext import torch import torch.nn as nn from torch import TensorType try: import transformers from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ BaseModelOutputWithPoolingAndCrossAttentions from transformers.modeling_utils import no_init_weights except ImportError as e: transformers = None class BaseModelOutput: pass class PretrainedConfig: pass from .hf_configs import arch_dict # utils def _camel2snake(s): return re.sub(r'(? 
<!^)(?=[A-Z])', '_', s).lower()
torch.Tensor: # calculate ground-truth and cache if enabled if self.prev_num_logits != num_logits or device not in self.labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) if self.world_size > 1 and self.local_loss: labels = labels + num_logits * self.rank if self.cache_labels: self.labels[device] = labels self.prev_num_logits = num_logits else: labels = self.labels[device] return labels def get_logits(self, image_features, text_features, logit_scale): if self.world_size > 1: all_image_features, all_text_features = gather_features( image_features, text_features, self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) if self.local_loss: logits_per_image = logit_scale * image_features @ all_text_features.T logits_per_text = logit_scale * text_features @ all_image_features.T else: logits_per_image = logit_scale * all_image_features @ all_text_features.T logits_per_text = logits_per_image.T else: logits_per_image = logit_scale * image_features @ text_features.T logits_per_text = logit_scale * text_features @ image_features.T return logits_per_image, logits_per_text def forward(self, image_features, text_features, logit_scale): device = image_features.device logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) labels = self.get_ground_truth(device, logits_per_image.shape[0]) total_loss = ( F.cross_entropy(logits_per_image, labels, reduction='none') + F.cross_entropy(logits_per_text, labels, reduction='none') ) / 2 scale_factor = (total_loss.shape[0] / image_features.shape[0]) total_loss = torch.mean(total_loss * scale_factor) show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor) return {"contrastive_loss": total_loss, "show_loss": show_loss} class DiscoClipLoss(nn.Module): def __init__( self, rank=0, world_size=1, use_horovod=False, ): super().__init__() self.rank = rank self.world_size = world_size self.use_horovod = use_horovod # cache state self.prev_num_logits = 0 self.labels = {} def get_ground_truth(self, device, num_logits) -> torch.Tensor: # calculate ground-truth and cache if enabled if self.prev_num_logits != num_logits or device not in self.labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) labels = labels + num_logits * self.rank self.labels[device] = labels self.prev_num_logits = num_logits else: labels = self.labels[device] return labels def forward(self, image_features, text_features, logit_scale): device = image_features.device all_image_features, all_text_features = gather_features( image_features, text_features, True, True, self.rank, self.world_size, self.use_horovod) logits_per_image = logit_scale * image_features @ all_text_features.T logits_per_text = logit_scale * text_features @ all_image_features.T labels = self.get_ground_truth(device, logits_per_image.shape[0]) total_loss = ( F.cross_entropy(logits_per_image, labels, reduction='none') + F.cross_entropy(logits_per_text, labels, reduction='none') ) / 2 total_loss = torch.mean(total_loss) show_loss = all_reduce(total_loss.detach().clone()) / self.world_size return {"contrastive_loss": total_loss, "show_loss": show_loss} class FlashClipLoss(nn.Module): def __init__( self, rank=0, world_size=1, use_horovod=False, ): super().__init__() self.rank = rank self.world_size = world_size self.use_horovod = use_horovod # cache state self.prev_num_logits = 0 self.labels = {} def get_ground_truth(self, device, num_logits) -> torch.Tensor: # calculate ground-truth and cache if
enabled if self.prev_num_logits != num_logits or device not in self.labels: labels = torch.arange(num_logits, device=device, dtype=torch.long) self.labels[device] = labels self.prev_num_logits = num_logits else: labels = self.labels[device] return labels def forward(self, image_features, text_features, logit_scale): device = image_features.device all_image_features, all_text_features = gather_features( image_features, text_features, False, False, self.rank, self.world_size, self.use_horovod) labels = self.get_ground_truth(device, all_image_features.shape[0]) i2t_loss = _cal_flash_loss(logit_scale * all_image_features, all_text_features, labels) t2i_loss = _cal_flash_loss(logit_scale * all_text_features, all_image_features, labels) total_loss = (i2t_loss + t2i_loss) / 2 scale_factor = (total_loss.shape[0] / image_features.shape[0]) total_loss = torch.mean(total_loss * scale_factor) show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor) return {"contrastive_loss": total_loss, "show_loss": show_loss} class RingClipLoss(nn.Module): def __init__( self, rank=0, world_size=1, use_horovod=False, ): super().__init__() self.rank = rank self.world_size = world_size self.use_horovod = use_horovod # cache state self.labels = {} def forward(self, image_features, text_features, logit_scale): device = image_features.device q = image_features k = text_features l = logit_scale if device in self.labels: labels = self.labels[device] else: labels = torch.arange(q.shape[0], device=device, dtype=torch.long) self.labels[device] = labels i2t_loss = cal_ring_loss(q, k, labels, l) t2i_loss = cal_ring_loss(k, q, labels, l) total_loss = (i2t_loss + t2i_loss) / 2 total_loss = total_loss.mean() show_loss = all_reduce(total_loss.detach().clone()) / self.world_size return {"contrastive_loss": total_loss, "show_loss": show_loss} class InfClipLoss(nn.Module): def __init__( self, rank=0, world_size=1, use_horovod=False, ): super().__init__() self.rank = rank self.world_size = world_size self.use_horovod = use_horovod # cache state self.labels = {} def forward(self, image_features, text_features, logit_scale): device = image_features.device q = image_features k = text_features l = logit_scale if device in self.labels: labels = self.labels[device] else: labels = torch.arange(q.shape[0], device=device, dtype=torch.long) self.labels[device] = labels i2t_loss = cal_inf_loss(q, k, labels, l) t2i_loss = cal_inf_loss(k, q, labels, l) total_loss = (i2t_loss + t2i_loss) / 2 total_loss = total_loss.mean() show_loss = all_reduce(total_loss.detach().clone()) / self.world_size return {"contrastive_loss": total_loss, "show_loss": show_loss} # NOTE: debug code for checking ring loss gradient # rank = dist.get_rank() # q = image_features.detach().clone().requires_grad_() # k = text_features.detach().clone().requires_grad_() # l = logit_scale.detach().clone().requires_grad_() # all_q, all_k = gather_features( # q, k, # self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) # if rank == 0: # import numpy as np # np.save('all_q.npy', all_q.detach().cpu().numpy()) # np.save('all_k.npy', all_k.detach().cpu().numpy()) # labels = torch.arange(all_q.shape[0], device=device, dtype=torch.long) # qk = l * all_q @ all_k.T # lse = torch.logsumexp(qk, dim=1) # # numerator = torch.einsum("md,md->m",l * all_q, all_k[labels]) # numerator = qk[torch.arange(qk.shape[0]), labels] # i2t_loss = -numerator + lse # lse = torch.logsumexp(qk.T, dim=1) # # numerator = torch.einsum("md,md->m", l * all_k,
all_q[labels]) # numerator = qk.T[labels, torch.arange(qk.shape[0])] # t2i_loss = -numerator + lse # # i2t_loss = F.cross_entropy(qk, labels, reduction='none') # # t2i_loss = F.cross_entropy(qk.T, labels, reduction='none') # total_loss = (i2t_loss + t2i_loss) / 2 # total_loss.sum().backward() # q1 = image_features.detach().clone().requires_grad_() # k1 = text_features.detach().clone().requires_grad_() # l1 = logit_scale.detach().clone().requires_grad_() # labels = torch.arange(q1.shape[0], device=device, dtype=torch.long) # i2t_loss1 = _cal_ring_loss(l1 * q1, k1, labels) # t2i_loss1 = _cal_ring_loss(l1 * k1, q1, labels) # total_loss1 = (i2t_loss1 + t2i_loss1) / 2 # q1.retain_grad(); k1.retain_grad(); l1.retain_grad() # total_loss1.sum().backward() # if rank == 1: # import numpy as np # np.save('q_r1_grad.npy', q.grad.detach().cpu().numpy()) # np.save('q1_r1_grad.npy', q1.grad.detach().cpu().numpy()) # print(q.grad, q1.grad) # print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)), torch.max(torch.abs(l.grad - l1.grad))) # exit(0) class CoCaLoss(ClipLoss): def __init__( self, caption_loss_weight, clip_loss_weight, pad_id=0, # pad_token for open_clip custom tokenizer local_loss=False, gather_with_grad=False, cache_labels=False, rank=0, world_size=1, use_horovod=False, ): super().__init__( local_loss=local_loss, gather_with_grad=gather_with_grad, cache_labels=cache_labels, rank=rank, world_size=world_size, use_horovod=use_horovod ) self.clip_loss_weight = clip_loss_weight self.caption_loss_weight = caption_loss_weight self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): clip_loss = torch.tensor(0) if self.clip_loss_weight: clip_loss = super().forward(image_features, text_features, logit_scale) clip_loss = self.clip_loss_weight * clip_loss caption_loss = self.caption_loss( logits.permute(0, 2, 1), labels, ) caption_loss = caption_loss * self.caption_loss_weight if output_dict: return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} return clip_loss, caption_loss class DistillClipLoss(ClipLoss): def dist_loss(self, teacher_logits, student_logits): return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) def forward( self, image_features, text_features, logit_scale, dist_image_features, dist_text_features, dist_logit_scale, output_dict=False, ): logits_per_image, logits_per_text = \ self.get_logits(image_features, text_features, logit_scale) dist_logits_per_image, dist_logits_per_text = \ self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) contrastive_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 distill_loss = ( self.dist_loss(dist_logits_per_image, logits_per_image) + self.dist_loss(dist_logits_per_text, logits_per_text) ) / 2 if output_dict: return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} return contrastive_loss, distill_loss def neighbour_exchange(from_rank, to_rank, tensor, group=None): tensor_recv = torch.zeros_like(tensor) send_op = torch.distributed.P2POp( torch.distributed.isend, tensor, to_rank, group=group, ) recv_op = torch.distributed.P2POp( torch.distributed.irecv, tensor_recv, from_rank, group=group, ) reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) for req in reqs: req.wait() return tensor_recv 
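# One neighbour_exchange call is a single hop of a ring: this rank sends `tensor`
# to to_rank while receiving a same-shaped buffer from from_rank; batching the
# isend/irecv pair via batch_isend_irecv lets the opposite-direction transfers
# complete without deadlock. Commented sketch of how SigLipLoss (below) chains
# hops, for a hypothetical 4-rank run on rank r:
# right = (r + 1) % world_size; left = (r - 1 + world_size) % world_size
# feats = text_features
# for _ in range(world_size - 1):  # 3 hops: left neighbour, left-of-left, ...
#     feats = neighbour_exchange(left, right, feats)  # now holds a neighbour's text features
# so every rank scores its images against every other rank's texts without an all_gather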
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): tensor_from_left = torch.zeros_like(tensor_to_right) tensor_from_right = torch.zeros_like(tensor_to_left) send_op_left = torch.distributed.P2POp( torch.distributed.isend, tensor_to_left, left_rank, group=group, ) send_op_right = torch.distributed.P2POp( torch.distributed.isend, tensor_to_right, right_rank, group=group, ) recv_op_left = torch.distributed.P2POp( torch.distributed.irecv, tensor_from_left, left_rank, group=group, ) recv_op_right = torch.distributed.P2POp( torch.distributed.irecv, tensor_from_right, right_rank, group=group, ) reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) for req in reqs: req.wait() return tensor_from_right, tensor_from_left class NeighbourExchange(torch.autograd.Function): @staticmethod def forward(ctx, from_rank, to_rank, group, tensor): ctx.group = group ctx.from_rank = from_rank ctx.to_rank = to_rank return neighbour_exchange(from_rank, to_rank, tensor, group=group) @staticmethod def backward(ctx, grad_output): return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): return NeighbourExchange.apply(from_rank, to_rank, group, tensor) class NeighbourExchangeBidir(torch.autograd.Function): @staticmethod def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): ctx.group = group ctx.left_rank = left_rank ctx.right_rank = right_rank return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) @staticmethod def backward(ctx, *grad_outputs): return (None, None, None) + \ NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) class SigLipLoss(nn.Module): """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 @article{zhai2023sigmoid, title={Sigmoid loss for language image pre-training}, author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, journal={arXiv preprint arXiv:2303.15343}, year={2023} } """ def __init__( self, cache_labels=False, rank=0, world_size=1, bidir=True, use_horovod=False, ): super().__init__() self.cache_labels = cache_labels self.rank = rank self.world_size = world_size assert not use_horovod # FIXME need to look at hvd ops for ring transfers self.use_horovod = use_horovod self.bidir = bidir # cache state FIXME cache not currently used, worthwhile? 
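# get_ground_truth below builds a dense (n, n) target of 2 * eye - 1: +1 on the
# diagonal for matched image-text pairs and -1 everywhere else; with
# negative_only=True it stays all -1, which is what the text features received
# from other ranks are scored against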
self.prev_num_logits = 0 self.labels = {} def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) if not negative_only: labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels return labels def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): logits = logit_scale * image_features @ text_features.T if logit_bias is not None: logits += logit_bias return logits def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) labels = self.get_ground_truth( image_features.device, image_features.dtype, image_features.shape[0], negative_only=negative_only, ) loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] return loss def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): loss = self._loss(image_features, text_features, logit_scale, logit_bias) if self.world_size > 1: # exchange text features w/ neighbour world_size - 1 times right_rank = (self.rank + 1) % self.world_size left_rank = (self.rank - 1 + self.world_size) % self.world_size if self.bidir: text_features_to_right = text_features_to_left = text_features num_bidir, remainder = divmod(self.world_size - 1, 2) for i in range(num_bidir): text_features_recv = neighbour_exchange_bidir_with_grad( left_rank, right_rank, text_features_to_left, text_features_to_right, ) for f in text_features_recv: loss += self._loss( image_features, f, logit_scale, logit_bias, negative_only=True, ) text_features_to_left, text_features_to_right = text_features_recv if remainder: text_features_recv = neighbour_exchange_with_grad( left_rank, right_rank, text_features_to_right) loss += self._loss( image_features, text_features_recv, logit_scale, logit_bias, negative_only=True, ) else: text_features_to_right = text_features for i in range(self.world_size - 1): text_features_from_left = neighbour_exchange_with_grad( left_rank, right_rank, text_features_to_right) loss += self._loss( image_features, text_features_from_left, logit_scale, logit_bias, negative_only=True, ) text_features_to_right = text_features_from_left return {"contrastive_loss": loss} if output_dict else loss ================================================ FILE: inf_clip/models/modified_resnet.py ================================================ from collections import OrderedDict import torch from torch import nn from torch.nn import functional as F from ..utils import freeze_batch_norm_2d class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1): super().__init__() # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.act2 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.act3 = nn.ReLU(inplace=True) self.downsample = None self.stride = stride if stride > 1 or inplanes != planes * Bottleneck.expansion: # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 self.downsample = nn.Sequential(OrderedDict([ ("-1", nn.AvgPool2d(stride)), ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), ("1", nn.BatchNorm2d(planes * self.expansion)) ])) def forward(self, x: torch.Tensor): identity = x out = self.act1(self.bn1(self.conv1(x))) out = self.act2(self.bn2(self.conv2(out))) out = self.avgpool(out) out = self.bn3(self.conv3(out)) if self.downsample is not None: identity = self.downsample(x) out += identity out = self.act3(out) return out class AttentionPool2d(nn.Module): def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): super().__init__() self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) self.k_proj = nn.Linear(embed_dim, embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads def forward(self, x): x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC x, _ = F.multi_head_attention_forward( query=x, key=x, value=x, embed_dim_to_check=x.shape[-1], num_heads=self.num_heads, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight, in_proj_weight=None, in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0., out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias, use_separate_proj_weight=True, training=self.training, need_weights=False ) return x[0] class ModifiedResNet(nn.Module): """ A ResNet class that is similar to torchvision's but contains the following changes: - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - The final pooling layer is a QKV attention instead of an average pool """ def __init__(self, layers, output_dim, heads, image_size=224, width=64): super().__init__() self.output_dim = output_dim self.image_size = image_size # the 3-layer stem self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(width // 2) self.act1 = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(width // 2) self.act2 = nn.ReLU(inplace=True) self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(width) self.act3 = nn.ReLU(inplace=True) self.avgpool = nn.AvgPool2d(2) # residual layers self._inplanes = width # this is a *mutable* variable used during construction self.layer1 = self._make_layer(width, layers[0]) self.layer2 = self._make_layer(width * 2, layers[1], stride=2) self.layer3 = self._make_layer(width * 4, layers[2], stride=2) self.layer4 = self._make_layer(width * 8, layers[3], stride=2) embed_dim = width * 32 # the ResNet feature dimension self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) self.init_parameters() def _make_layer(self, planes, blocks, stride=1): layers = [Bottleneck(self._inplanes, planes, stride)] self._inplanes = planes * Bottleneck.expansion for _ in range(1, blocks): layers.append(Bottleneck(self._inplanes, planes)) return nn.Sequential(*layers) def init_parameters(self): if self.attnpool is not None: std = self.attnpool.c_proj.in_features ** -0.5 nn.init.normal_(self.attnpool.q_proj.weight, std=std) nn.init.normal_(self.attnpool.k_proj.weight, std=std) nn.init.normal_(self.attnpool.v_proj.weight, std=std) nn.init.normal_(self.attnpool.c_proj.weight, std=std) for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: for name, param in resnet_block.named_parameters(): if name.endswith("bn3.weight"): nn.init.zeros_(param) def lock(self, unlocked_groups=0, freeze_bn_stats=False): assert unlocked_groups == 0, 'partial locking not currently supported for this model' for param in self.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): # FIXME support for non-transformer pass def stem(self, x): x = self.act1(self.bn1(self.conv1(x))) x = self.act2(self.bn2(self.conv2(x))) x = self.act3(self.bn3(self.conv3(x))) x = self.avgpool(x) return x def forward(self, x): x = self.stem(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.attnpool(x) return x ================================================ FILE: inf_clip/models/pos_embed.py ================================================ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
# -------------------------------------------------------- # Position embedding utils # -------------------------------------------------------- import numpy as np import torch # -------------------------------------------------------- # 2D sine-cosine position embedding # References: # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py # MoCo v3: https://github.com/facebookresearch/moco-v3 # -------------------------------------------------------- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) omega /= embed_dim / 2. omega = 1. / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb # -------------------------------------------------------- # Interpolate position embeddings for high-resolution # References: # DeiT: https://github.com/facebookresearch/deit # -------------------------------------------------------- def interpolate_pos_embed(model, checkpoint_model): if 'pos_embed' in checkpoint_model: pos_embed_checkpoint = checkpoint_model['pos_embed'] embedding_size = pos_embed_checkpoint.shape[-1] num_patches = model.patch_embed.num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) # height (== width) for the new position embedding new_size = int(num_patches ** 0.5) # class_token and dist_token are kept unchanged if orig_size != new_size: print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) checkpoint_model['pos_embed'] = new_pos_embed ================================================ FILE: 
inf_clip/models/timm_model.py ================================================ """ timm model adapter Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. """ import logging from collections import OrderedDict import torch import torch.nn as nn try: import timm from timm.models.layers import Mlp, to_2tuple try: # old timm imports < 0.8.1 from timm.models.layers.attention_pool2d import RotAttentionPool2d from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d except ImportError: # new timm imports >= 0.8.1 from timm.layers import RotAttentionPool2d from timm.layers import AttentionPool2d as AbsAttentionPool2d except ImportError: timm = None from ..utils import freeze_batch_norm_2d class TimmModel(nn.Module): """ timm model adapter """ def __init__( self, model_name, embed_dim, image_size=224, pool='avg', proj='linear', proj_bias=False, drop=0., drop_path=None, patch_drop=None, pretrained=False, ): super().__init__() if timm is None: raise RuntimeError("Please `pip install timm` to use timm models.") self.image_size = to_2tuple(image_size) # setup kwargs that may not be common across all models timm_kwargs = {} if drop_path is not None: timm_kwargs['drop_path_rate'] = drop_path if patch_drop is not None: timm_kwargs['patch_drop_rate'] = patch_drop custom_pool = pool in ('abs_attn', 'rot_attn') if proj: assert proj in ("linear", "mlp", "none") extra_proj = proj in ("linear", "mlp") if not extra_proj and not custom_pool: # use network classifier head as projection if no proj specified and no custom pooling used # if projection is explicitly set to "none" will be pass through from network trunk proj_dim = 0 if proj == 'none' else embed_dim self.trunk = timm.create_model( model_name, num_classes=proj_dim, global_pool=pool, pretrained=pretrained, **timm_kwargs, ) prev_chs = embed_dim else: self.trunk = timm.create_model( model_name, pretrained=pretrained, **timm_kwargs, ) feat_size = self.trunk.default_cfg.get('pool_size', None) feature_ndim = 1 if not feat_size else 2 if custom_pool: assert feature_ndim == 2 # if attn pooling used, remove both classifier and default pool self.trunk.reset_classifier(0, global_pool='') else: # reset global pool if pool config set, otherwise leave as network default reset_kwargs = dict(global_pool=pool) if pool else {} self.trunk.reset_classifier(0, **reset_kwargs) prev_chs = self.trunk.num_features head_layers = OrderedDict() # Add custom pooling to head if pool == 'abs_attn': head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) prev_chs = embed_dim elif pool == 'rot_attn': head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) prev_chs = embed_dim # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used if proj == 'linear': head_layers['drop'] = nn.Dropout(drop) head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) elif proj == 'mlp': head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) self.head = nn.Sequential(head_layers) def lock(self, unlocked_groups=0, freeze_bn_stats=False): """ lock modules Args: unlocked_groups (int): leave last n layer groups unlocked (default: 0) """ if not unlocked_groups: # lock full model for param in self.trunk.parameters(): param.requires_grad = False if freeze_bn_stats: freeze_batch_norm_2d(self.trunk) else: # NOTE: partial freeze requires latest timm (master) 
branch and is subject to change try: # FIXME import here until API stable and in an official release from timm.models.helpers import group_parameters, group_modules except ImportError: raise RuntimeError( 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') matcher = self.trunk.group_matcher() gparams = group_parameters(self.trunk, matcher) max_layer_id = max(gparams.keys()) max_layer_id = max_layer_id - unlocked_groups for group_idx in range(max_layer_id + 1): group = gparams[group_idx] for param in group: self.trunk.get_parameter(param).requires_grad = False if freeze_bn_stats: gmodules = group_modules(self.trunk, matcher, reverse=True) gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} freeze_batch_norm_2d(self.trunk, gmodules) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): try: self.trunk.set_grad_checkpointing(enable) except Exception as e: logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') def forward_trunk(self, x): return self.trunk(x) def forward_head(self, x): return self.head(x) def forward(self, x): x = self.forward_trunk(x) x = self.forward_head(x) return x ================================================ FILE: inf_clip/models/tokenizer.py ================================================ """ CLIP tokenizer Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import gzip import html import os import random import string from functools import lru_cache, partial from typing import Callable, List, Optional, Union import warnings import ftfy import numpy as np import regex as re import torch # https://stackoverflow.com/q/62691279 os.environ["TOKENIZERS_PARALLELISM"] = "false" _nltk_init = False DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP @lru_cache() def default_bpe(): return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") @lru_cache() def bytes_to_unicode(): """ Returns list of utf-8 byte and a corresponding list of unicode strings. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup tables between utf-8 bytes and unicode strings. And avoids mapping to whitespace/control characters the bpe code barfs on. """ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) cs = bs[:] n = 0 for b in range(2**8): if b not in bs: bs.append(b) cs.append(2**8+n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) def get_pairs(word): """Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). 
""" pairs = set() prev_char = word[0] for char in word[1:]: pairs.add((prev_char, char)) prev_char = char return pairs def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def whitespace_clean(text): text = " ".join(text.split()) text = text.strip() return text def _clean_canonicalize(x): # basic, remove whitespace, remove punctuation, lower case return canonicalize_text(basic_clean(x)) def _clean_lower(x): # basic, remove whitespace, lower case return whitespace_clean(basic_clean(x)).lower() def _clean_whitespace(x): # basic, remove whitespace return whitespace_clean(basic_clean(x)) def get_clean_fn(type: str): if type == 'canonicalize': return _clean_canonicalize elif type == 'lower': return _clean_lower elif type == 'whitespace': return _clean_whitespace else: assert False, f"Invalid clean function ({type})." def canonicalize_text( text, *, keep_punctuation_exact_string=None, trans_punctuation: dict = str.maketrans("", "", string.punctuation), ): """Returns canonicalized `text` (lowercase and punctuation removed). From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 Args: text: string to be canonicalized. keep_punctuation_exact_string: If provided, then this exact string kept. For example providing '{}' will keep any occurrences of '{}' (but will still remove '{' and '}' that appear separately). """ text = text.replace("_", " ") if keep_punctuation_exact_string: text = keep_punctuation_exact_string.join( part.translate(trans_punctuation) for part in text.split(keep_punctuation_exact_string) ) else: text = text.translate(trans_punctuation) text = text.lower() text = " ".join(text.split()) return text.strip() class SimpleTokenizer(object): def __init__( self, bpe_path: str = default_bpe(), additional_special_tokens: Optional[List[str]] = None, context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, clean: str = 'lower', reduction_mask: str = '' ): self.byte_encoder = bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') merges = merges[1:49152-256-2+1] merges = [tuple(merge.split()) for merge in merges] vocab = list(bytes_to_unicode().values()) vocab = vocab + [v+'' for v in vocab] for merge in merges: vocab.append(''.join(merge)) special_tokens = ['', ''] if additional_special_tokens: special_tokens += additional_special_tokens vocab.extend(special_tokens) self.encoder = dict(zip(vocab, range(len(vocab)))) self.decoder = {v: k for k, v in self.encoder.items()} self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {t:t for t in special_tokens} special = "|".join(special_tokens) self.pat = re.compile( special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE, ) self.vocab_size = len(self.encoder) self.all_special_ids = [self.encoder[t] for t in special_tokens] self.sot_token_id = self.all_special_ids[0] self.eot_token_id = self.all_special_ids[1] self.context_length = context_length self.clean_fn = get_clean_fn(clean) self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token[:-1]) + ( token[-1] + '',) pairs = get_pairs(word) if not pairs: return token+'' while True: bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) if bigram not in 
self.bpe_ranks: break first, second = bigram new_word = [] i = 0 while i < len(word): try: j = word.index(first, i) new_word.extend(word[i:j]) i = j except Exception: new_word.extend(word[i:]) break if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second) i += 2 else: new_word.append(word[i]) i += 1 new_word = tuple(new_word) word = new_word if len(word) == 1: break else: pairs = get_pairs(word) word = ' '.join(word) self.cache[token] = word return word def encode(self, text): bpe_tokens = [] text = self.clean_fn(text) for token in re.findall(self.pat, text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) return bpe_tokens def decode(self, tokens): text = ''.join([self.decoder[token] for token in tokens]) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') return text def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s) Parameters ---------- texts : Union[str, List[str]] An input string or a list of input strings to tokenize context_length : int The context length to use; all CLIP models use 77 as the context length Returns ------- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] """ if isinstance(texts, str): texts = [texts] context_length = context_length or self.context_length assert context_length, 'Please set a valid context length' if self.reduction_fn is not None: # use reduction strategy for tokenize if set, otherwise default to truncation below return self.reduction_fn( texts, context_length=context_length, sot_token_id=self.sot_token_id, eot_token_id=self.eot_token_id, encode_fn=self.encode, ) all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = self.eot_token_id result[i, :len(tokens)] = torch.tensor(tokens) return result _tokenizer = SimpleTokenizer() def decode(output_ids: torch.Tensor): output_ids = output_ids.cpu().numpy() return _tokenizer.decode(output_ids) def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor: return _tokenizer(texts, context_length=context_length) def random_mask_tokenize( texts: Union[str, List[str]], context_length: int, sot_token_id: int, eot_token_id: int, encode_fn: Callable, shuffle: bool = False, ): all_tokens = [encode_fn(text) for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): tokens = torch.tensor(tokens) num_tokens = len(tokens) if num_tokens > context_length - 2: # 2 for sot and eot token num_keep = context_length - 2 indices = torch.randperm(len(tokens)) indices = indices[:num_keep] if not shuffle: indices = indices.msort() tokens = tokens[indices] num_tokens = num_keep result[i, 0] = sot_token_id result[i, 1:num_tokens + 1] = tokens result[i, num_tokens + 1] = eot_token_id return result def simple_mask_tokenize( texts: Union[str, List[str]], context_length: int, sot_token_id: int, eot_token_id: int, encode_fn: Callable, ): all_tokens = [encode_fn(text) for text in texts] result = 
torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): num_tokens = len(tokens) if num_tokens > context_length - 2: # 2 for sot and eot token num_keep = context_length - 2 start_index = random.randint(0, num_tokens - num_keep) # high is incl tokens = tokens[start_index: start_index + num_keep] tokens = [sot_token_id] + tokens + [eot_token_id] result[i, :len(tokens)] = torch.tensor(tokens) return result def syntax_mask_tokenize( texts: Union[str, List[str]], context_length: int, sot_token_id: int, eot_token_id: int, encode_fn: Callable, ) -> torch.LongTensor: """ Returns the tokenized representation of given input string(s). Apply syntax masking before tokenize. """ import nltk global _nltk_init if not _nltk_init: # run them for the first time nltk.download('punkt') nltk.download('averaged_perceptron_tagger') _nltk_init = True def get_order(x): if x.startswith('NN'): return 1 elif x.startswith('JJ'): return 2 elif x.startswith('VB'): return 3 else: return 4 # syntax masking new_texts = [] for text in texts: list_tokens = nltk.tokenize.word_tokenize(text) pos_tags = nltk.pos_tag(list_tokens) # sample the words by get_order method order_list = [get_order(tag) for _, tag in pos_tags] sorted_ids = np.argsort(np.array(order_list)) sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens new_text = '' for token in sampled_tokens: new_text = new_text + str(token) + ' ' new_text = new_text.strip() new_texts.append(new_text) texts = new_texts all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): # still need first truncate because some words produces two tokens if len(tokens) > context_length: tokens = tokens[:context_length] # Truncate tokens[-1] = eot_token_id result[i, :len(tokens)] = torch.tensor(tokens) return result def get_reduction_mask_fn(type: str): """ Choose strategy for dropping (masking) tokens to achieve target context length""" assert type in ('simple', 'random', 'shuffle', 'syntax') if type == 'simple': return simple_mask_tokenize # randomly select block [start:end] elif type == 'random': return random_mask_tokenize # randomly drop tokens (keep order) elif type == 'shuffle': return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order) elif type == 'syntax': return syntax_mask_tokenize # randomly drop prioritized by syntax class HFTokenizer: """HuggingFace tokenizer wrapper""" def __init__( self, tokenizer_name: str, context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, clean: str = 'whitespace', strip_sep_token: bool = False, language: Optional[str] = None, **kwargs ): from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs) set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) if callable(set_lang_fn): self.set_lang_fn = set_lang_fn if language is not None: self.set_language(language) self.context_length = context_length self.clean_fn = get_clean_fn(clean) self.strip_sep_token = strip_sep_token def save_pretrained(self, dest): self.tokenizer.save_pretrained(dest) def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive 
tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] context_length = context_length or self.context_length assert context_length, 'Please set a valid context length in class init or call.' texts = [self.clean_fn(text) for text in texts] input_ids = self.tokenizer.batch_encode_plus( texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True, ).input_ids if self.strip_sep_token: input_ids = torch.where( input_ids == self.tokenizer.sep_token_id, torch.zeros_like(input_ids), input_ids, ) return input_ids def set_language(self, src_lang): if hasattr(self, 'set_lang_fn'): self.set_lang_fn(src_lang) else: warnings.warn('Cannot set language for the tokenizer.') class SigLipTokenizer: """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs """ VOCAB_FILES = { # english, vocab_size=32_000 "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model", # used in multilingual models (mT5, PaLI), vocab_size=250_000 "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", } def __init__( self, tokenizer_name: str, context_length: Optional[int] = 64, ): from transformers import T5TokenizerFast if tokenizer_name in self.VOCAB_FILES: # FIXME temporary hack? import tempfile import fsspec vocab_file = self.VOCAB_FILES[tokenizer_name] with tempfile.NamedTemporaryFile('wb') as dst: with fsspec.open(vocab_file, 'rb') as src: dst.write(src.read()) self.tokenizer = T5TokenizerFast(dst.name, legacy=False) else: self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False) self.tokenizer.pad_token_id = 1 self.tokenizer.eos_token_id = 1 self.context_length = context_length def save_pretrained(self, dest): self.tokenizer.save_pretrained(dest) def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: # same cleaning as for default tokenizer, except lowercasing # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance if isinstance(texts, str): texts = [texts] context_length = context_length or self.context_length assert context_length, 'Please set a valid context length in class init or call.' texts = [canonicalize_text(basic_clean(text)) for text in texts] output = self.tokenizer( texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True, ) return output.input_ids ================================================ FILE: inf_clip/models/transform.py ================================================ import numbers import random import warnings from dataclasses import dataclass, asdict from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torchvision.transforms.functional as F from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ CenterCrop, ColorJitter, Grayscale from ..constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from ..utils import to_2tuple @dataclass class PreprocessCfg: size: Union[int, Tuple[int, int]] = 224 mode: str = 'RGB' mean: Tuple[float, ...] = OPENAI_DATASET_MEAN std: Tuple[float, ...] 
= OPENAI_DATASET_STD interpolation: str = 'bicubic' resize_mode: str = 'shortest' fill_color: int = 0 def __post_init__(self): assert self.mode in ('RGB',) @property def num_channels(self): return 3 @property def input_size(self): return (self.num_channels,) + to_2tuple(self.size) _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) def merge_preprocess_dict( base: Union[PreprocessCfg, Dict], overlay: Dict, ): """ Merge overlay key-value pairs on top of base preprocess cfg or dict. Input dicts are filtered based on PreprocessCfg fields. """ if isinstance(base, PreprocessCfg): base_clean = asdict(base) else: base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} if overlay: overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} base_clean.update(overlay_clean) return base_clean def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): return merge_preprocess_dict(base, kwargs) @dataclass class AugmentationCfg: scale: Tuple[float, float] = (0.9, 1.0) ratio: Optional[Tuple[float, float]] = None color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None re_prob: Optional[float] = None re_count: Optional[int] = None use_timm: bool = False # params for simclr_jitter_gray color_jitter_prob: float = None gray_scale_prob: float = None def _setup_size(size, error_msg): if isinstance(size, numbers.Number): return int(size), int(size) if isinstance(size, Sequence) and len(size) == 1: return size[0], size[0] if len(size) != 2: raise ValueError(error_msg) return size class ResizeKeepRatio: """ Resize and Keep Ratio Copy & paste from `timm` """ def __init__( self, size, longest=0., interpolation=InterpolationMode.BICUBIC, random_scale_prob=0., random_scale_range=(0.85, 1.05), random_aspect_prob=0., random_aspect_range=(0.9, 1.11) ): if isinstance(size, (list, tuple)): self.size = tuple(size) else: self.size = (size, size) self.interpolation = interpolation self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest self.random_scale_prob = random_scale_prob self.random_scale_range = random_scale_range self.random_aspect_prob = random_aspect_prob self.random_aspect_range = random_aspect_range @staticmethod def get_params( img, target_size, longest, random_scale_prob=0., random_scale_range=(0.85, 1.05), random_aspect_prob=0., random_aspect_range=(0.9, 1.11) ): """Get parameters """ source_size = img.size[::-1] # h, w h, w = source_size target_h, target_w = target_size ratio_h = h / target_h ratio_w = w / target_w ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) if random_scale_prob > 0 and random.random() < random_scale_prob: ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) ratio_factor = (ratio_factor, ratio_factor) else: ratio_factor = (1., 1.) if random_aspect_prob > 0 and random.random() < random_aspect_prob: aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] return size def __call__(self, img): """ Args: img (PIL Image): Image to be cropped and resized. 
Returns: PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size """ size = self.get_params( img, self.size, self.longest, self.random_scale_prob, self.random_scale_range, self.random_aspect_prob, self.random_aspect_range ) img = F.resize(img, size, self.interpolation) return img def __repr__(self): format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += f', interpolation={self.interpolation})' format_string += f', longest={self.longest:.3f})' return format_string def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: """Center crops and/or pads the given image. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. Args: img (PIL Image or Tensor): Image to be cropped. output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, it is used for both directions. fill (int, Tuple[int]): Padding color Returns: PIL Image or Tensor: Cropped image. """ if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: output_size = (output_size[0], output_size[0]) _, image_height, image_width = F.get_dimensions(img) crop_height, crop_width = output_size if crop_width > image_width or crop_height > image_height: padding_ltrb = [ (crop_width - image_width) // 2 if crop_width > image_width else 0, (crop_height - image_height) // 2 if crop_height > image_height else 0, (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, ] img = F.pad(img, padding_ltrb, fill=fill) _, image_height, image_width = F.get_dimensions(img) if crop_width == image_width and crop_height == image_height: return img crop_top = int(round((image_height - crop_height) / 2.0)) crop_left = int(round((image_width - crop_width) / 2.0)) return F.crop(img, crop_top, crop_left, crop_height, crop_width) class CenterCropOrPad(torch.nn.Module): """Crops the given image at the center. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. Args: size (sequence or int): Desired output size of the crop. If size is an int instead of sequence like (h, w), a square crop (size, size) is made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). """ def __init__(self, size, fill=0): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.fill = fill def forward(self, img): """ Args: img (PIL Image or Tensor): Image to be cropped. Returns: PIL Image or Tensor: Cropped image. """ return center_crop_or_pad(img, self.size, fill=self.fill) def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size})" def _convert_to_rgb(image): return image.convert('RGB') class color_jitter(object): """ Apply Color Jitter to the PIL image with a specified probability. """ def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): assert 0. <= p <= 1. 
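        # [editor's note] illustrative usage sketch, not part of the original
        # source: these wrappers apply their transform with probability p, as
        # used by the simclr_jitter_gray augmentation path in image_transform
        # below (the concrete values here are assumptions, not repo defaults):
        #   jitter = color_jitter(0.4, 0.4, 0.4, 0.1, p=0.8)
        #   gray = gray_scale(p=0.2)
        #   img = gray(jitter(img))  # each wrapper flips its own coin per call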
self.p = p self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) def __call__(self, img): if random.random() < self.p: return self.transf(img) else: return img class gray_scale(object): """ Apply Gray Scale to the PIL image with a specified probability. """ def __init__(self, p=0.2): assert 0. <= p <= 1. self.p = p self.transf = Grayscale(num_output_channels=3) def __call__(self, img): if random.random() < self.p: return self.transf(img) else: return img def image_transform( image_size: Union[int, Tuple[int, int]], is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, resize_mode: Optional[str] = None, interpolation: Optional[str] = None, fill_color: int = 0, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 std = std or OPENAI_DATASET_STD if not isinstance(std, (list, tuple)): std = (std,) * 3 interpolation = interpolation or 'bicubic' assert interpolation in ['bicubic', 'bilinear', 'random'] # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC resize_mode = resize_mode or 'shortest' assert resize_mode in ('shortest', 'longest', 'squash') if isinstance(aug_cfg, dict): aug_cfg = AugmentationCfg(**aug_cfg) else: aug_cfg = aug_cfg or AugmentationCfg() normalize = Normalize(mean=mean, std=std) if is_train: aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} use_timm = aug_cfg_dict.pop('use_timm', False) if use_timm: from timm.data import create_transform # timm can still be optional if isinstance(image_size, (tuple, list)): assert len(image_size) >= 2 input_size = (3,) + image_size[-2:] else: input_size = (3, image_size, image_size) aug_cfg_dict.setdefault('color_jitter', None) # disable by default # drop extra non-timm items aug_cfg_dict.pop('color_jitter_prob', None) aug_cfg_dict.pop('gray_scale_prob', None) train_transform = create_transform( input_size=input_size, is_training=True, hflip=0., mean=mean, std=std, re_mode='pixel', interpolation=interpolation, **aug_cfg_dict, ) else: train_transform = [ RandomResizedCrop( image_size, scale=aug_cfg_dict.pop('scale'), interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ] if aug_cfg.color_jitter_prob: assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 train_transform.extend([ color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) ]) if aug_cfg.gray_scale_prob: train_transform.extend([ gray_scale(aug_cfg.gray_scale_prob) ]) train_transform.extend([ ToTensor(), normalize, ]) train_transform = Compose(train_transform) if aug_cfg_dict: warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') return train_transform else: if resize_mode == 'longest': transforms = [ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), CenterCropOrPad(image_size, fill=fill_color) ] elif resize_mode == 'squash': if isinstance(image_size, int): image_size = (image_size, image_size) transforms = [ Resize(image_size, interpolation=interpolation_mode), ] else: assert resize_mode == 'shortest' if not isinstance(image_size, (tuple, list)): image_size = (image_size, image_size) if image_size[0] == image_size[1]: # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) 
transforms = [ Resize(image_size[0], interpolation=interpolation_mode) ] else: # resize shortest edge to matching target dim for non-square target transforms = [ResizeKeepRatio(image_size)] transforms += [CenterCrop(image_size)] transforms.extend([ _convert_to_rgb, ToTensor(), normalize, ]) return Compose(transforms) def image_transform_v2( cfg: PreprocessCfg, is_train: bool, aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, ): return image_transform( image_size=cfg.size, is_train=is_train, mean=cfg.mean, std=cfg.std, interpolation=cfg.interpolation, resize_mode=cfg.resize_mode, fill_color=cfg.fill_color, aug_cfg=aug_cfg, ) ================================================ FILE: inf_clip/models/transformer.py ================================================ from collections import OrderedDict import math from typing import Callable, List, Optional, Sequence, Tuple, Union from functools import partial import torch from torch import nn from torch.nn import functional as F from torch.utils.checkpoint import checkpoint from .pos_embed import get_2d_sincos_pos_embed from ..utils import to_2tuple class LayerNormFp32(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm (with cast back to input dtype).""" def forward(self, x: torch.Tensor): orig_type = x.dtype x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) return x.to(orig_type) class QuickGELU(nn.Module): # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class LayerScale(nn.Module): def __init__(self, dim, init_values=1e-5, inplace=False): super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x): return x.mul_(self.gamma) if self.inplace else x * self.gamma class PatchDropout(nn.Module): """ https://arxiv.org/abs/2212.00794 """ def __init__(self, prob, exclude_first_token=True): super().__init__() assert 0 <= prob < 1. self.prob = prob self.exclude_first_token = exclude_first_token # exclude CLS token def forward(self, x): if not self.training or self.prob == 0.: return x if self.exclude_first_token: cls_tokens, x = x[:, :1], x[:, 1:] else: cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) batch = x.size()[0] num_tokens = x.size()[1] batch_indices = torch.arange(batch) batch_indices = batch_indices[..., None] keep_prob = 1 - self.prob num_patches_keep = max(1, int(num_tokens * keep_prob)) rand = torch.randn(batch, num_tokens) patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices x = x[batch_indices, patch_indices_keep] if self.exclude_first_token: x = torch.cat((cls_tokens, x), dim=1) return x class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, scaled_cosine: bool = False, scale_heads: bool = False, logit_scale_max: float = math.log(1. / 0.01), batch_first: bool = True, attn_drop: float = 0., proj_drop: float = 0. 
): super().__init__() self.scaled_cosine = scaled_cosine self.scale_heads = scale_heads assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.logit_scale_max = logit_scale_max self.batch_first = batch_first self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention') # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) if qkv_bias: self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) else: self.in_proj_bias = None if self.scaled_cosine: self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) else: self.logit_scale = None self.attn_drop = nn.Dropout(attn_drop) if self.scale_heads: self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) else: self.head_scale = None self.out_proj = nn.Linear(dim, dim) self.out_drop = nn.Dropout(proj_drop) def forward(self, x, attn_mask: Optional[torch.Tensor] = None): if self.batch_first: x = x.transpose(0, 1) L, N, C = x.shape q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1) k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1) v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1) if attn_mask is not None and attn_mask.dtype == torch.bool: new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) new_attn_mask.masked_fill_(attn_mask, float("-inf")) attn_mask = new_attn_mask if self.logit_scale is not None: attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() attn = attn.view(N, self.num_heads, L, L) * logit_scale attn = attn.view(-1, L, L) if attn_mask is not None: attn = attn + attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) else: if self.use_fsdpa: x = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0., ) else: q = q * self.scale attn = torch.bmm(q, k.transpose(-1, -2)) if attn_mask is not None: attn += attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = torch.bmm(attn, v) if self.head_scale is not None: x = x.view(N, self.num_heads, L, C) * self.head_scale x = x.view(-1, L, C) x = x.transpose(0, 1).reshape(L, N, C) if self.batch_first: x = x.transpose(0, 1) x = self.out_proj(x) x = self.out_drop(x) return x class AttentionalPooler(nn.Module): def __init__( self, d_model: int, context_dim: int, n_head: int = 8, n_queries: int = 256, norm_layer: Callable = LayerNorm, ): super().__init__() self.query = nn.Parameter(torch.randn(n_queries, d_model)) self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True) self.ln_q = norm_layer(d_model) self.ln_k = norm_layer(context_dim) def forward(self, x: torch.Tensor): N = x.shape[0] x = self.ln_k(x) q = self.ln_q(self.query) out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0] return out class ResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, is_cross_attention: bool = False, batch_first: bool = True, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head, 
batch_first=batch_first) self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() if is_cross_attention: self.ln_1_kv = norm_layer(d_model) self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def attention( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None return self.attn( q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask )[0] def forward( self, q_x: torch.Tensor, k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x class CustomResidualAttentionBlock(nn.Module): def __init__( self, d_model: int, n_head: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, scale_cosine_attn: bool = False, scale_heads: bool = False, scale_attn: bool = False, scale_fc: bool = False, batch_first: bool = True, ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = Attention( d_model, n_head, scaled_cosine=scale_cosine_attn, scale_heads=scale_heads, batch_first=batch_first, ) self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), ("c_proj", nn.Linear(mlp_width, d_model)) ])) self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() def get_reference_weight(self): return self.mlp.c_fc.weight def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) x = x + self.ls_2(self.mlp(self.ln_2(x))) return x def _expand_token(token, batch_size: int): return token.view(1, 1, -1).expand(batch_size, -1, -1) class Transformer(nn.Module): def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, batch_first: bool = True, ): super().__init__() self.width = width self.layers = layers self.batch_first = batch_first self.grad_checkpointing = False self.resblocks = nn.ModuleList([ ResidualAttentionBlock( width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, batch_first=batch_first, ) for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): return self.resblocks[0].mlp.c_fc.int8_original_dtype return self.resblocks[0].mlp.c_fc.weight.dtype def forward(self, 
x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): if not self.batch_first: x = x.transpose(0, 1).contiguous() # NLD -> LND for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) if not self.batch_first: x = x.transpose(0, 1) # LND -> NLD return x class CustomTransformer(nn.Module): """ A custom transformer that can use different block types. """ def __init__( self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, ls_init_value: float = None, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, batch_first: bool = True, block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock', ): super().__init__() self.width = width self.layers = layers self.batch_first = batch_first # run trasnformer stack in batch first (N, L, D) self.grad_checkpointing = False if isinstance(block_types, str): block_types = [block_types] * layers assert len(block_types) == layers def _create_block(bt: str): if bt == 'CustomResidualAttentionBlock': return CustomResidualAttentionBlock( width, heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, batch_first=batch_first, ) else: assert False self.resblocks = nn.ModuleList([ _create_block(bt) for bt in block_types ]) def get_cast_dtype(self) -> torch.dtype: weight = self.resblocks[0].get_reference_weight() if hasattr(weight, 'int8_original_dtype'): return weight.int8_original_dtype return weight.dtype def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): if not self.batch_first: x = x.transpose(0, 1) # NLD -> LND for r in self.resblocks: if self.grad_checkpointing and not torch.jit.is_scripting(): # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 x = checkpoint(r, x, None, None, attn_mask) else: x = r(x, attn_mask=attn_mask) if not self.batch_first: x = x.transpose(0, 1) # NLD -> LND return x class VisionTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, ls_init_value: float = None, attentional_pool: bool = False, attn_pooler_queries: int = 256, attn_pooler_heads: int = 8, output_dim: int = 512, patch_dropout: float = 0., no_ln_pre: bool = False, pos_embed_type: str = 'learnable', pool_type: str = 'tok', final_ln_after_pool: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): super().__init__() assert pool_type in ('tok', 'avg', 'none') self.output_tokens = output_tokens image_height, image_width = self.image_size = to_2tuple(image_size) patch_height, patch_width = self.patch_size = to_2tuple(patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled self.output_dim = output_dim self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) # class embeddings and positional embeddings scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) if pos_embed_type == 'learnable': self.positional_embedding = nn.Parameter( scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) elif pos_embed_type == 'sin_cos_2d': # fixed sin-cos embedding 
assert self.grid_size[0] == self.grid_size[1],\ 'currently sin cos 2d pos embedding only supports square input' self.positional_embedding = nn.Parameter( torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) else: raise ValueError # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) self.transformer = Transformer( width, layers, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) if attentional_pool: if isinstance(attentional_pool, str): self.attn_pool_type = attentional_pool self.pool_type = 'none' if attentional_pool in ('parallel', 'cascade'): self.attn_pool = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=attn_pooler_queries, ) self.attn_pool_contrastive = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=1, ) else: assert False else: self.attn_pool_type = '' self.pool_type = pool_type self.attn_pool = AttentionalPooler( output_dim, width, n_head=attn_pooler_heads, n_queries=attn_pooler_queries, ) self.attn_pool_contrastive = None pool_dim = output_dim else: self.attn_pool = None pool_dim = width self.pool_type = pool_type self.ln_post = norm_layer(pool_dim) self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) self.init_parameters() def lock(self, unlocked_groups=0, freeze_bn_stats=False): for param in self.parameters(): param.requires_grad = False if unlocked_groups != 0: groups = [ [ self.conv1, self.class_embedding, self.positional_embedding, self.ln_pre, ], *self.transformer.resblocks[:-1], [ self.transformer.resblocks[-1], self.ln_post, ], self.proj, ] def _unlock(x): if isinstance(x, Sequence): for g in x: _unlock(g) else: if isinstance(x, torch.nn.Parameter): x.requires_grad = True else: for p in x.parameters(): p.requires_grad = True _unlock(groups[-unlocked_groups:]) def init_parameters(self): # FIXME OpenAI CLIP did not define an init for the VisualTransformer # TODO experiment if default PyTorch init, below, or alternate init is best. 
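        # [editor's note] descriptive comment, not in the original source: the
        # disabled scheme below mirrors the text-tower initialization that
        # TextTransformer.init_parameters (further down in this file) actually
        # applies; here init_parameters is a no-op and PyTorch defaults are kept.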
# nn.init.normal_(self.class_embedding, std=self.scale) # nn.init.normal_(self.positional_embedding, std=self.scale) # # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) # attn_std = self.transformer.width ** -0.5 # fc_std = (2 * self.transformer.width) ** -0.5 # for block in self.transformer.resblocks: # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) # # if self.text_projection is not None: # nn.init.normal_(self.text_projection, std=self.scale) pass @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if self.pool_type == 'avg': pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] elif self.pool_type == 'tok': pooled, tokens = x[:, 0], x[:, 1:] else: pooled = tokens = x return pooled, tokens def forward(self, x: torch.Tensor): x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] # class embeddings and positional embeddings x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) # shape = [*, grid ** 2 + 1, width] x = x + self.positional_embedding.to(x.dtype) x = self.patch_dropout(x) x = self.ln_pre(x) x = self.transformer(x) if self.attn_pool is not None: if self.attn_pool_contrastive is not None: # This is untested, WIP pooling that should match paper x = self.ln_post(x) # TBD LN first or separate one after each pool? tokens = self.attn_pool(x) if self.attn_pool_type == 'parallel': pooled = self.attn_pool_contrastive(x) else: assert self.attn_pool_type == 'cascade' pooled = self.attn_pool_contrastive(tokens) else: # this is the original OpenCLIP CoCa setup, does not match paper x = self.attn_pool(x) x = self.ln_post(x) pooled, tokens = self._global_pool(x) elif self.final_ln_after_pool: pooled, tokens = self._global_pool(x) pooled = self.ln_post(pooled) else: x = self.ln_post(x) pooled, tokens = self._global_pool(x) if self.proj is not None: pooled = pooled @ self.proj if self.output_tokens: return pooled, tokens return pooled def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): if pool_type == 'first': pooled, tokens = x[:, 0], x[:, 1:] elif pool_type == 'last': pooled, tokens = x[:, -1], x[:, :-1] elif pool_type == 'argmax': # take features from the eot embedding (eot_token is the highest number in each sequence) assert text is not None pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x else: pooled = tokens = x return pooled, tokens class TextTransformer(nn.Module): output_tokens: torch.jit.Final[bool] def __init__( self, context_length: int = 77, vocab_size: int = 49408, width: int = 512, heads: int = 8, layers: int = 12, mlp_ratio: float = 4.0, ls_init_value: float = None, output_dim: int = 512, embed_cls: bool = False, no_causal_mask: bool = False, pad_id: int = 0, pool_type: str = 'argmax', proj_bias: bool = False, act_layer: Callable = nn.GELU, norm_layer: Callable = LayerNorm, output_tokens: bool = False, ): super().__init__() assert pool_type in ('first', 'last', 'argmax', 'none') self.output_tokens = output_tokens self.num_pos = self.context_length = context_length self.vocab_size = vocab_size self.width = width 
self.output_dim = output_dim self.heads = heads self.pad_id = pad_id self.pool_type = pool_type self.token_embedding = nn.Embedding(vocab_size, width) if embed_cls: self.cls_emb = nn.Parameter(torch.empty(width)) self.num_pos += 1 else: self.cls_emb = None self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) self.transformer = Transformer( width=width, layers=layers, heads=heads, mlp_ratio=mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, ) self.ln_final = norm_layer(width) if no_causal_mask: self.attn_mask = None else: self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) if proj_bias: self.text_projection = nn.Linear(width, output_dim) else: self.text_projection = nn.Parameter(torch.empty(width, output_dim)) self.init_parameters() def init_parameters(self): nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) if self.cls_emb is not None: nn.init.normal_(self.cls_emb, std=0.01) proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) attn_std = self.transformer.width ** -0.5 fc_std = (2 * self.transformer.width) ** -0.5 for block in self.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std) nn.init.normal_(block.attn.out_proj.weight, std=proj_std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) if self.text_projection is not None: if isinstance(self.text_projection, nn.Linear): nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) if self.text_projection.bias is not None: nn.init.zeros_(self.text_projection.bias) else: nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.transformer.grad_checkpointing = enable def build_causal_mask(self): # lazily create causal attention mask, with full attention between the tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(self.num_pos, self.num_pos) mask.fill_(float("-inf")) mask.triu_(1) # zero out the lower diagonal return mask def build_cls_mask(self, text, cast_dtype: torch.dtype): cls_mask = (text != self.pad_id).unsqueeze(1) cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) additive_mask.fill_(0) additive_mask.masked_fill_(~cls_mask, float("-inf")) additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) return additive_mask def forward(self, text): cast_dtype = self.transformer.get_cast_dtype() seq_len = text.shape[1] x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] attn_mask = self.attn_mask if self.cls_emb is not None: seq_len += 1 x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) cls_mask = self.build_cls_mask(text, cast_dtype) if attn_mask is not None: attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] x = x + self.positional_embedding[:seq_len].to(cast_dtype) x = self.transformer(x, attn_mask=attn_mask) # x.shape = [batch_size, n_ctx, transformer.width] if self.cls_emb is not None: # presence of appended cls embed (CoCa) overrides pool_type, always take last token pooled, tokens = text_global_pool(x, pool_type='last') pooled = self.ln_final(pooled) # final LN applied after pooling in this case else: x = self.ln_final(x) pooled, tokens = text_global_pool(x, text, 
pool_type=self.pool_type)

        if self.text_projection is not None:
            if isinstance(self.text_projection, nn.Linear):
                pooled = self.text_projection(pooled)
            else:
                pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled


class MultimodalTransformer(Transformer):
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_dim: int = 512,
            batch_first: bool = True,
    ):
        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            batch_first=batch_first,
        )
        self.context_length = context_length
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
                batch_first=batch_first,
            )
            for _ in range(layers)
        ])
        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def init_parameters(self):
        # this class subclasses Transformer directly, so width/layers/resblocks/cross_attn
        # live on self; there is no self.transformer submodule here
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        for block in self.cross_attn:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, image_embs, text_embs):
        seq_len = text_embs.shape[1]
        if not self.batch_first:
            image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
            text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND

        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
                text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
            else:
                text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
                text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)

        if not self.batch_first:
            text_embs = text_embs.permute(1, 0, 2)  # LND -> NLD

        out = self.ln_final(text_embs)
        if self.text_projection is not None:
            out = out @ self.text_projection

        return out

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

================================================ FILE: inf_clip/openai.py ================================================

""" OpenAI pretrained model functions

Adapted from
https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. """ import os import warnings from typing import List, Optional, Union import torch from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url from .models.clip_arch import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype __all__ = ["list_openai_models", "load_openai_model"] def list_openai_models() -> List[str]: """Returns the names of available CLIP models""" return list_pretrained_models_by_tag('openai') def load_openai_model( name: str, precision: Optional[str] = None, device: Optional[Union[str, torch.device]] = None, cache_dir: Optional[str] = None, ): """Load a CLIP model Parameters ---------- name : str A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict precision: str Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. device : Union[str, torch.device] The device to put the loaded model cache_dir : Optional[str] The directory to cache the downloaded model weights Returns ------- model : torch.nn.Module The CLIP model preprocess : Callable[[PIL.Image], torch.Tensor] A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if precision is None: precision = 'fp32' if device == 'cpu' else 'fp16' if get_pretrained_url(name, 'openai'): model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) elif os.path.isfile(name): model_path = name else: raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") try: # loading JIT archive model = torch.jit.load(model_path, map_location="cpu").eval() state_dict = None except RuntimeError: # loading saved state dict state_dict = torch.load(model_path, map_location="cpu") # Build a non-jit model from the OpenAI jitted model state dict cast_dtype = get_cast_dtype(precision) try: model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) except KeyError: sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use model = model.to(device) # FIXME support pure fp16/bf16 precision modes if precision != 'fp16': model.float() if precision == 'bf16': # for bf16, convert back to low-precision convert_weights_to_lp(model, dtype=torch.bfloat16) # add mean / std attributes for consistency with OpenCLIP models model.visual.image_mean = OPENAI_DATASET_MEAN model.visual.image_std = OPENAI_DATASET_STD return model ================================================ FILE: inf_clip/pretrained.py ================================================ import hashlib import os import urllib import warnings from functools import partial from typing import Dict, Union import torch import numpy as np from tqdm import tqdm from .models.clip_arch import CLIP, CustomTextCLIP from .models.transformer import TextTransformer, Transformer from .constants import ( IMAGENET_MEAN, IMAGENET_STD, INCEPTION_MEAN, INCEPTION_STD, OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, ) __version__ = "2.26.1" try: from huggingface_hub import hf_hub_download hf_hub_download = partial(hf_hub_download, 
library_name="open_clip", library_version=__version__) _has_hf_hub = True except ImportError: hf_hub_download = None _has_hf_hub = False def _pcfg(url='', hf_hub='', **kwargs): # OpenAI / OpenCLIP defaults return { 'url': url, 'hf_hub': hf_hub, 'mean': OPENAI_DATASET_MEAN, 'std': OPENAI_DATASET_STD, 'interpolation': 'bicubic', 'resize_mode': 'shortest', **kwargs, } def _slpcfg(url='', hf_hub='', **kwargs): # SiGLIP defaults return { 'url': url, 'hf_hub': hf_hub, 'mean': INCEPTION_MEAN, 'std': INCEPTION_STD, 'interpolation': 'bicubic', 'resize_mode': 'squash', **kwargs, } def _apcfg(url='', hf_hub='', **kwargs): # CLIPA defaults return { 'url': url, 'hf_hub': hf_hub, 'mean': IMAGENET_MEAN, 'std': IMAGENET_STD, 'interpolation': 'bilinear', 'resize_mode': 'squash', **kwargs, } def _mccfg(url='', hf_hub='', **kwargs): # MobileCLIP return { 'url': url, 'hf_hub': hf_hub, 'mean': (0., 0., 0.), 'std': (1., 1., 1.), 'interpolation': 'bilinear', 'resize_mode': 'shortest', **kwargs, } _RN50 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), cc12m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), ) _RN50_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), cc12m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), ) _RN101 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), ) _RN101_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), yfcc15m=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), ) _RN50x4 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), ) _RN50x16 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), ) _RN50x64 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), ) _VITB32 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), laion2b_e16=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), # DataComp-XL models 
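    # For reference (illustrative note, values follow from the _pcfg defaults defined above):
    # each _pcfg(...) entry in these tables is just a plain dict of download location plus
    # preprocessing defaults, e.g.
    #   _pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') == {
    #       'url': '', 'hf_hub': 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K/',
    #       'mean': OPENAI_DATASET_MEAN, 'std': OPENAI_DATASET_STD,
    #       'interpolation': 'bicubic', 'resize_mode': 'shortest',
    #   }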
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), # DataComp-M models datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), # DataComp-S models datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), ) _VITB32_quickgelu = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), metaclip_400m=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"), metaclip_fullcc=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"), ) _VITB32_256 = dict( datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), ) _VITB16 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), # DataComp-L models datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), # DFN dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/') ) _VITB16_quickgelu = dict( metaclip_400m=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"), metaclip_fullcc=_pcfg( 
"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"), ) _VITB16_PLUS_240 = dict( laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), ) _VITL14 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), laion400m_e31=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), laion400m_e32=_pcfg( "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), laion2b_s32b_b82k=_pcfg( hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', mean=INCEPTION_MEAN, std=INCEPTION_STD), # DataComp-XL models datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), ) _VITL14_quickgelu = dict( metaclip_400m=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"), metaclip_fullcc=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"), dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'), ) _VITL14_336 = dict( openai=_pcfg( "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), ) _VITH14 = dict( laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), ) _VITH14_quickgelu = dict( metaclip_fullcc=_pcfg( "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"), dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14/', interpolation="bicubic", resize_mode="squash" ), ) _VITH14_378_quickgelu = dict( dfn5b=_pcfg( hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', interpolation="bicubic", resize_mode="squash" ), ) _VITg14 = dict( laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), ) _VITbigG14 = dict( laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), ) _robertaViTB32 = dict( laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), ) _xlmRobertaBaseViTB32 = dict( laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), ) _xlmRobertaLargeFrozenViTH14 = dict( frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), ) _convnext_base = dict( laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), ) _convnext_base_w = dict( laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), ) _convnext_base_w_320 = dict( laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), ) _convnext_large_d = dict( 
laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), ) _convnext_large_d_320 = dict( laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), ) _convnext_xxlarge = dict( laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), ) _coca_VITB32 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') ) _coca_VITL14 = dict( laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') ) _PRETRAINED = { "RN50": _RN50, "RN50-quickgelu": _RN50_quickgelu, "RN101": _RN101, "RN101-quickgelu": _RN101_quickgelu, "RN50x4": _RN50x4, "RN50x16": _RN50x16, "RN50x64": _RN50x64, "ViT-B-32": _VITB32, "ViT-B-32-256": _VITB32_256, "ViT-B-32-quickgelu": _VITB32_quickgelu, "ViT-B-16": _VITB16, "ViT-B-16-quickgelu": _VITB16_quickgelu, "ViT-B-16-plus-240": _VITB16_PLUS_240, "ViT-L-14": _VITL14, "ViT-L-14-quickgelu": _VITL14_quickgelu, "ViT-L-14-336": _VITL14_336, "ViT-H-14": _VITH14, "ViT-H-14-quickgelu": _VITH14_quickgelu, "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu, "ViT-g-14": _VITg14, "ViT-bigG-14": _VITbigG14, "roberta-ViT-B-32": _robertaViTB32, "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, "convnext_base": _convnext_base, "convnext_base_w": _convnext_base_w, "convnext_base_w_320": _convnext_base_w_320, "convnext_large_d": _convnext_large_d, "convnext_large_d_320": _convnext_large_d_320, "convnext_xxlarge": _convnext_xxlarge, "coca_ViT-B-32": _coca_VITB32, "coca_ViT-L-14": _coca_VITL14, "EVA01-g-14": dict( # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), ), "EVA01-g-14-plus": dict( # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), ), "EVA02-B-16": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), ), "EVA02-L-14": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), ), "EVA02-L-14-336": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), ), "EVA02-E-14": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), ), "EVA02-E-14-plus": dict( # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), ), "ViT-B-16-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), ), "ViT-B-16-SigLIP-256": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), ), "ViT-B-16-SigLIP-i18n-256": dict( 
webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), ), "ViT-B-16-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), ), "ViT-B-16-SigLIP-512": dict( webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), ), "ViT-L-16-SigLIP-256": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), ), "ViT-L-16-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), ), "ViT-SO400M-14-SigLIP": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), ), "ViT-SO400M-14-SigLIP-384": dict( webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), ), "ViT-L-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), ), "ViT-L-14-CLIPA-336": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), ), "ViT-H-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), ), "ViT-H-14-CLIPA-336": dict( laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), ), "ViT-bigG-14-CLIPA": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), ), "ViT-bigG-14-CLIPA-336": dict( datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), ), "nllb-clip-base": dict( v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), ), "nllb-clip-large": dict( v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), ), "nllb-clip-base-siglip": dict( v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'), ), "nllb-clip-large-siglip": dict( v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), ), "MobileCLIP-S1": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')), "MobileCLIP-S2": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')), "MobileCLIP-B": dict( datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'), datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'), ), "ViTamin-S": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'), ), "ViTamin-S-LTT": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'), ), "ViTamin-B": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'), ), "ViTamin-B-LTT": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'), ), "ViTamin-L": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'), ), "ViTamin-L-256": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'), ), "ViTamin-L-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'), ), "ViTamin-L-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'), ), "ViTamin-L2": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'), ), "ViTamin-L2-256": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'), ), "ViTamin-L2-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'), ), "ViTamin-L2-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'), ), "ViTamin-XL-256": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'), ), "ViTamin-XL-336": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'), ), "ViTamin-XL-384": dict( datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'), ), } def _clean_tag(tag: str): # normalize pretrained tags return 
tag.lower().replace('-', '_')


def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]


def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    models = []
    tag = _clean_tag(tag)
    for k in _PRETRAINED.keys():
        if tag in _PRETRAINED[k]:
            models.append(k)
    return models


def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    tags = []
    if model in _PRETRAINED:
        tags.extend(_PRETRAINED[model].keys())
    return tags


def is_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return False
    return _clean_tag(tag) in _PRETRAINED[model]


def get_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return {}
    model_pretrained = _PRETRAINED[model]
    return model_pretrained.get(_clean_tag(tag), {})


def get_pretrained_url(model: str, tag: str):
    cfg = get_pretrained_cfg(model, _clean_tag(tag))
    return cfg.get('url', '')


def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break
                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed.
Run `pip install huggingface_hub`.') return _has_hf_hub def download_pretrained_from_hf( model_id: str, filename: str = 'open_clip_pytorch_model.bin', revision=None, cache_dir: Union[str, None] = None, ): has_hf_hub(True) cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) return cached_file def download_pretrained( cfg: Dict, force_hf_hub: bool = False, cache_dir: Union[str, None] = None, ): target = '' if not cfg: return target download_url = cfg.get('url', '') download_hf_hub = cfg.get('hf_hub', '') if download_hf_hub and force_hf_hub: # use HF hub even if url exists download_url = '' if download_url: target = download_pretrained_from_url(download_url, cache_dir=cache_dir) elif download_hf_hub: has_hf_hub(True) # we assume the hf_hub entries in pretrained config combine model_id + filename in # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. model_id, filename = os.path.split(download_hf_hub) if filename: target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) else: target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) return target @torch.no_grad() def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): """ Load weights from .npz checkpoints for official Google big_vision image-text models Currently the SigLIP source models are supported and a CustomTextCLIP destination model w/ timm image encoder. """ from timm.layers import resample_patch_embed, resample_abs_pos_embed def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) interpolation = 'bilinear' antialias = False def _convert_timm_img(module, prefix): embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: embed_conv_w = resample_patch_embed( embed_conv_w, module.patch_embed.proj.weight.shape[-2:], interpolation=interpolation, antialias=antialias, verbose=True, ) module.patch_embed.proj.weight.copy_(embed_conv_w) module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) if module.cls_token is not None: module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) if pos_embed_w.shape != module.pos_embed.shape: assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, new_size=module.patch_embed.grid_size, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) module.pos_embed.copy_(pos_embed_w) mha_sub, b_sub, ln1_sub = (0, 0, 1) for i, block in enumerate(module.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.qkv.weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 
'value')])) block.attn.qkv.bias.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) for r in range(2): getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) if module.attn_pool is not None: block_prefix = f'{prefix}MAPHead_0/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) module.attn_pool.kv.weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) module.attn_pool.kv.bias.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) for r in range(2): getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) def _convert_openclip_transformer(module: Transformer, prefix): for i, block in enumerate(module.resblocks.children()): block_prefix = f'{prefix}encoderblock_{i}/' mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.attn.in_proj_weight.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) block.attn.in_proj_bias.copy_(torch.cat([ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) def _convert_openclip_txt(module: TextTransformer, prefix): module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) module.positional_embedding.copy_(pos_embed_w) _convert_openclip_transformer(module.transformer, 
prefix=prefix + 'Encoder_0/') module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) _convert_timm_img(model.visual.trunk, 'params/img/') _convert_openclip_txt(model.text, 'params/txt/') model.logit_bias.copy_(_n2p(w['params/b'])[0]) model.logit_scale.copy_(_n2p(w['params/t'])[0]) @torch.no_grad() def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): def _convert_timm_img(state_dict): if fastvit: from timm.models.fastvit import checkpoint_filter_fn else: from timm.models.vision_transformer_hybrid import checkpoint_filter_fn timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) timm_state_dict = {'visual.trunk.' + k: v for k, v in timm_state_dict.items()} return timm_state_dict def _convert_openclip_txt(state_dict, prefix='text_encoder.'): text_dict = {} for k, v in state_dict.items(): if not k.startswith(prefix): continue k = k.replace(prefix, '') k = k.replace('projection_layer', 'text_projection') k = k.replace('embedding_layer', 'token_embedding') if k.startswith('positional_embedding.pos_embed.pos_embed'): k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') v = v.squeeze() k = k.replace('final_layer_norm', 'ln_final') k = k.replace('pre_norm_mha.0', 'ln_1') k = k.replace('pre_norm_mha.1', 'attn') k = k.replace('pre_norm_ffn.0', 'ln_2') k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') k = k.replace('qkv_proj.weight', 'in_proj_weight') k = k.replace('qkv_proj.bias', 'in_proj_bias') k = k.replace('transformer.', 'transformer.resblocks.') text_dict['text.' 
+ k] = v return text_dict image_dict = _convert_timm_img(state_dict) text_dict = _convert_openclip_txt(state_dict) out_dict = {**image_dict, **text_dict} out_dict['logit_scale'] = state_dict['logit_scale'] return out_dict def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) state_dict = convert_mobile_clip_state_dict(model, state_dict) if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: # convert b model state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False) return state_dict ================================================ FILE: inf_clip/train/data.py ================================================ import ast import json import logging import math import os import random import sys import braceexpand from dataclasses import dataclass from multiprocessing import Value import numpy as np import pandas as pd import torch import torchvision.datasets as datasets import webdataset as wds from PIL import Image from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info from torch.utils.data.distributed import DistributedSampler from webdataset.filters import _shuffle from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample class CsvDataset(Dataset): def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None): logging.debug(f'Loading csv data from {input_filename}.') df = pd.read_csv(input_filename, sep=sep) self.images = df[img_key].tolist() self.captions = df[caption_key].tolist() self.transforms = transforms logging.debug('Done loading data.') self.tokenize = tokenizer def __len__(self): return len(self.captions) def __getitem__(self, idx): images = self.transforms(Image.open(str(self.images[idx]))) texts = self.tokenize([str(self.captions[idx])])[0] return images, texts class SharedEpoch: def __init__(self, epoch: int = 0): self.shared_epoch = Value('i', epoch) def set_value(self, epoch): self.shared_epoch.value = epoch def get_value(self): return self.shared_epoch.value @dataclass class DataInfo: dataloader: DataLoader sampler: DistributedSampler = None shared_epoch: SharedEpoch = None def set_epoch(self, epoch): if self.shared_epoch is not None: self.shared_epoch.set_value(epoch) if self.sampler is not None and isinstance(self.sampler, DistributedSampler): self.sampler.set_epoch(epoch) def expand_urls(urls, weights=None): if weights is None: expanded_urls = wds.shardlists.expand_urls(urls) return expanded_urls, None if isinstance(urls, str): urllist = urls.split("::") weights = weights.split('::') assert len(weights) == len(urllist),\ f"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match." 
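        # Illustrative example (hypothetical shard names; behavior follows from the code below):
        # each '::'-separated url spec may use brace notation, and its weight is broadcast to
        # every shard it expands to, e.g.
        #   expand_urls("a-{00..02}.tar::b-{00..01}.tar", weights="1.0::2.0")
        # returns (['a-00.tar', 'a-01.tar', 'a-02.tar', 'b-00.tar', 'b-01.tar'],
        #          [1.0, 1.0, 1.0, 2.0, 2.0]).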
weights = [float(weight) for weight in weights] all_urls, all_weights = [], [] for url, weight in zip(urllist, weights): expanded_url = list(braceexpand.braceexpand(url)) expanded_weights = [weight for _ in expanded_url] all_urls.extend(expanded_url) all_weights.extend(expanded_weights) return all_urls, all_weights else: all_urls = list(urls) return all_urls, weights def get_dataset_size(shards): shards_list, _ = expand_urls(shards) dir_path = os.path.dirname(shards_list[0]) sizes_filename = os.path.join(dir_path, 'sizes.json') len_filename = os.path.join(dir_path, '__len__') if os.path.exists(sizes_filename): sizes = json.load(open(sizes_filename, 'r')) total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list]) elif os.path.exists(len_filename): # FIXME this used to be eval(open(...)) but that seemed rather unsafe total_size = ast.literal_eval(open(len_filename, 'r').read()) else: total_size = None # num samples undefined # some common dataset sizes (at time of authors last download) # CC3M (train): 2905954 # CC12M: 10968539 # LAION-400M: 407332084 # LAION-2B (english): 2170337258 num_shards = len(shards_list) return total_size, num_shards def get_imagenet(args, preprocess_fns, split): assert split in ["train", "val", "v2"] is_train = split == "train" preprocess_train, preprocess_val = preprocess_fns if split == "v2": from imagenetv2_pytorch import ImageNetV2Dataset dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) else: if is_train: data_path = args.imagenet_train preprocess_fn = preprocess_train else: data_path = args.imagenet_val preprocess_fn = preprocess_val assert data_path dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) if is_train: idxs = np.zeros(len(dataset.targets)) target_array = np.array(dataset.targets) k = 50 for c in range(1000): m = target_array == c n = len(idxs[m]) arr = np.zeros(n) arr[:k] = 1 np.random.shuffle(arr) idxs[m] = arr idxs = idxs.astype('int') sampler = SubsetRandomSampler(np.where(idxs)[0]) else: sampler = None dataloader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, ) return DataInfo(dataloader=dataloader, sampler=sampler) def count_samples(dataloader): os.environ["WDS_EPOCH"] = "0" n_elements, n_batches = 0, 0 for images, texts in dataloader: n_batches += 1 n_elements += len(images) assert len(images) == len(texts) return n_elements, n_batches def filter_no_caption_or_no_image(sample): has_caption = ('txt' in sample or 'json' in sample) has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample) return has_caption and has_image def log_and_continue(exn): """Call in an exception handler to ignore any exception, issue a warning, and continue.""" logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') return True def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): """Return function over iterator that groups key, value pairs into samples. :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to lower case (Default value = True) """ current_sample = None for filesample in data: assert isinstance(filesample, dict) # FIXME this is a bit of a hack to handle the fact that the CC3M/LAION400m dataset has some empty files. 
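        # For intuition (illustrative note): keys=base_plus_ext splits a tar member name at its
        # extension, e.g. '000123.jpg' -> ('000123', 'jpg'), so consecutive members
        # 000123.jpg / 000123.txt are grouped into one sample
        # {'__key__': '000123', 'jpg': <bytes>, 'txt': <bytes>}.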
        try:
            fname, value = filesample["fname"], filesample["data"]
        except KeyError as exn:
            continue
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefixes aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour using the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


def json_fetch(data, key='caption'):
    for sample in data:
        if 'json' in sample:
            if isinstance(sample['json'], dict):
                value = sample['json'].get(key, None)
            else:
                value = sample['json']
        else:
            value = sample['txt']
        if isinstance(value, str):
            sample['txt'] = [value]
        elif isinstance(value, list):
            sample['txt'] = value
        else:
            # print(f"Expected {key} to be a string or list of strings, got {type(value)} {sample}")
            continue
        yield sample


_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000


class detshuffle2(wds.PipelineStage):
    def __init__(
            self,
            bufsize=1000,
            initial=100,
            seed=0,
            epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, we use the worker's seed, this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)


class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        weights=None,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls, weights = expand_urls(urls, weights)
        self.urls = urls
        self.weights = weights
        if self.weights is not None:
            assert len(self.urls) == len(self.weights), \
                f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        for _ in range(self.nshards):
            if self.weights is None:
                yield dict(url=self.rng.choice(self.urls))
            else:
                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])


def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_shards = None
    if is_train:
        if args.train_num_samples is not None:
            num_samples = args.train_num_samples
        else:
            num_samples, num_shards = get_dataset_size(input_shards)
            if not num_samples:
                raise RuntimeError(
                    'Currently, the number of dataset samples must be specified for the training dataset. '
                    'Please specify it via `--train-num-samples` if no dataset length info is present.')
    else:
        # Eval will just exhaust the iterator if the size is not specified.
        num_samples = args.val_num_samples or 0

    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc

    if is_train and args.train_data_upsampling_factors is not None:
        assert resampled, "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
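    # Illustrative usage (hypothetical shard names and counts): with --dataset-resampled,
    # per-source upsampling can be expressed as, e.g.,
    #   --train-data "cc3m-{0000..0575}.tar::laion-{00000..41455}.tar"
    #   --train_data_upsampling_factors "2::1"
    # which gives every cc3m shard twice the sampling weight of every laion shard.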
if resampled: pipeline = [ResampledShards2( input_shards, weights=args.train_data_upsampling_factors, deterministic=True, epoch=shared_epoch, )] else: pipeline = [wds.SimpleShardList(input_shards)] # at this point we have an iterator over all the shards if is_train: if not resampled: pipeline.extend([ detshuffle2( bufsize=_SHARD_SHUFFLE_SIZE, initial=_SHARD_SHUFFLE_INITIAL, seed=args.seed, epoch=shared_epoch, ), wds.split_by_node, wds.split_by_worker, ]) pipeline.extend([ # at this point, we have an iterator over the shards assigned to each worker at each node tarfile_to_samples_nothrow, # wds.tarfile_to_samples(handler=log_and_continue), wds.shuffle( bufsize=_SAMPLE_SHUFFLE_SIZE, initial=_SAMPLE_SHUFFLE_INITIAL, ), ]) else: pipeline.extend([ wds.split_by_worker, # at this point, we have an iterator over the shards assigned to each worker wds.tarfile_to_samples(handler=log_and_continue), ]) pipeline.extend([ wds.select(filter_no_caption_or_no_image), wds.decode("pilrgb", handler=log_and_continue), json_fetch, wds.rename(image="jpg;png;jpeg;webp", text="txt;json"), wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]), wds.to_tuple("image", "text"), wds.batched(args.batch_size, partial=not is_train) ]) dataset = wds.DataPipeline(*pipeline) if is_train: if not resampled: num_shards = num_shards or len(expand_urls(input_shards)[0]) assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers' # roll over and repeat a few samples to get same number of full batches on each node round_fn = math.floor if floor else math.ceil global_batch_size = args.batch_size * args.world_size num_batches = round_fn(num_samples / global_batch_size) num_workers = max(1, args.workers) num_worker_batches = round_fn(num_batches / num_workers) # per dataloader worker num_batches = num_worker_batches * num_workers num_samples = num_batches * global_batch_size dataset = dataset.with_epoch(num_worker_batches) # each worker is iterating over this else: # last batches are partial, eval is done on single (master) node num_batches = math.ceil(num_samples / args.batch_size) dataloader = wds.WebLoader( dataset, batch_size=None, shuffle=False, num_workers=args.workers, persistent_workers=args.workers > 0, ) # FIXME not clear which approach is better, with_epoch before vs after dataloader? 
# hoping to resolve via https://github.com/webdataset/webdataset/issues/169 # if is_train: # # roll over and repeat a few samples to get same number of full batches on each node # global_batch_size = args.batch_size * args.world_size # num_batches = math.ceil(num_samples / global_batch_size) # num_workers = max(1, args.workers) # num_batches = math.ceil(num_batches / num_workers) * num_workers # num_samples = num_batches * global_batch_size # dataloader = dataloader.with_epoch(num_batches) # else: # # last batches are partial, eval is done on single (master) node # num_batches = math.ceil(num_samples / args.batch_size) # add meta-data to dataloader instance for convenience dataloader.num_batches = num_batches dataloader.num_samples = num_samples dataloader.batch_size = args.batch_size return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch) def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): input_filename = args.train_data if is_train else args.val_data assert input_filename dataset = CsvDataset( input_filename, preprocess_fn, img_key=args.csv_img_key, caption_key=args.csv_caption_key, sep=args.csv_separator, tokenizer=tokenizer ) num_samples = len(dataset) sampler = DistributedSampler(dataset) if args.distributed and is_train else None shuffle = is_train and sampler is None dataloader = DataLoader( dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.workers, pin_memory=True, sampler=sampler, drop_last=is_train, ) dataloader.num_samples = num_samples dataloader.num_batches = len(dataloader) return DataInfo(dataloader, sampler) class SyntheticDataset(Dataset): def __init__( self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None, ): self.transform = transform self.image_size = image_size self.caption = caption self.image = Image.new('RGB', image_size) self.dataset_size = dataset_size self.preprocess_txt = lambda text: tokenizer(text)[0] def __len__(self): return self.dataset_size def __getitem__(self, idx): if self.transform is not None: image = self.transform(self.image) return image, self.preprocess_txt(self.caption) def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None): image_size = preprocess_fn.transforms[0].size dataset = SyntheticDataset( transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer) num_samples = len(dataset) sampler = DistributedSampler(dataset) if args.distributed and is_train else None shuffle = is_train and sampler is None dataloader = DataLoader( dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.workers, pin_memory=True, sampler=sampler, drop_last=is_train, ) dataloader.num_samples = num_samples dataloader.num_batches = len(dataloader) return DataInfo(dataloader, sampler) def get_dataset_fn(data_path, dataset_type): if dataset_type == "webdataset": return get_wds_dataset elif dataset_type == "csv": return get_csv_dataset elif dataset_type == "synthetic": return get_synthetic_dataset elif dataset_type == "auto": ext = data_path.split('.')[-1] if ext in ['csv', 'tsv']: return get_csv_dataset elif ext in ['tar']: return get_wds_dataset else: raise ValueError( f"Tried to figure out dataset type, but failed for extension {ext}.") else: raise ValueError(f"Unsupported dataset type: {dataset_type}") def get_data(args, preprocess_fns, epoch=0, tokenizer=None): preprocess_train, preprocess_val = preprocess_fns data = {} if args.train_data or args.dataset_type == "synthetic": 
data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer) if args.val_data: data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( args, preprocess_val, is_train=False, tokenizer=tokenizer) if args.imagenet_val is not None: data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val") if args.imagenet_v2 is not None: data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2") return data ================================================ FILE: inf_clip/train/engine.py ================================================ import json import logging import math import os import time from contextlib import nullcontext import numpy as np import pynvml import torch import torch.nn.functional as F import torch.utils.checkpoint as torch_checkpoint import torch.distributed as dist from torch.nn.parallel.distributed import DistributedDataParallel from tqdm import tqdm from .utils import get_autocast, is_master from inf_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, CLIP, CustomTextCLIP from inf_clip.models.loss import ClipLoss try: import wandb except ImportError: wandb = None def accuracy(output, target, topk=(1,)): pred = output.topk(max(topk), 1, True, True)[1].t() correct = pred.eq(target.view(1, -1).expand_as(pred)) return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] def get_clip_metrics(image_features, text_features, logit_scale): metrics = {} logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu() logits_per_text = logits_per_image.t().detach().cpu() logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} ground_truth = torch.arange(len(text_features)).view(-1, 1) for name, logit in logits.items(): ranking = torch.argsort(logit, descending=True) preds = torch.where(ranking == ground_truth)[1] preds = preds.detach().cpu().numpy() metrics[f"{name}_mean_rank"] = preds.mean() + 1 metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 for k in [1, 5, 10]: metrics[f"{name}_R@{k}"] = np.mean(preds < k) return metrics def maybe_compute_generative_loss(model_out): if "logits" in model_out and "labels" in model_out: token_logits = model_out["logits"] token_labels = model_out["labels"] return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels) def get_memory(): pynvml.nvmlInit() # NOTE: 0 denotes GPU index. 
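    # Caveat (added note): index 0 is only the right handle when each process sees exactly
    # one GPU (e.g. one process per GPU with CUDA_VISIBLE_DEVICES set); in other setups the
    # local rank's device index would typically need to be passed to
    # pynvml.nvmlDeviceGetHandleByIndex instead.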
handle = pynvml.nvmlDeviceGetHandleByIndex(0) meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) return meminfo.used / 1024**3 def seconds_to_hms(seconds): hours, remainder = divmod(seconds, 3600) minutes, seconds = divmod(remainder, 60) hours = int(hours); minutes = int(minutes); seconds = int(seconds) return f"{hours}:{minutes:02d}:{seconds:02d}" def cal_grad_norm(model): total_norm = 0 for p in model.parameters(): if p.grad is not None: param_norm = p.grad.data.norm(2) total_norm += param_norm.item() ** 2 total_norm = total_norm ** 0.5 return total_norm def assign_learning_rate(optimizer, new_lr): for param_group in optimizer.param_groups: param_group["lr"] = new_lr def _warmup_lr(base_lr, warmup_length, step): return base_lr * (step + 1) / warmup_length def const_lr(optimizer, base_lr, warmup_length, steps): def _lr_adjuster(step): if step < warmup_length: lr = _warmup_lr(base_lr, warmup_length, step) else: lr = base_lr assign_learning_rate(optimizer, lr) return lr return _lr_adjuster def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.): def _lr_adjuster(step): start_cooldown_step = steps - cooldown_steps if step < warmup_length: lr = _warmup_lr(base_lr, warmup_length, step) else: if step < start_cooldown_step: lr = base_lr else: e = step - start_cooldown_step es = steps - start_cooldown_step # linear decay if power == 1; polynomial decay otherwise; decay = (1 - (e/es)) ** cooldown_power lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr assign_learning_rate(optimizer, lr) return lr return _lr_adjuster def cosine_lr(optimizer, base_lr, warmup_length, steps): def _lr_adjuster(step): if step < warmup_length: lr = _warmup_lr(base_lr, warmup_length, step) else: e = step - warmup_length es = steps - warmup_length lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr assign_learning_rate(optimizer, lr) return lr return _lr_adjuster def postprocess_clip_output(model_out): return { "image_features": model_out[0], "text_features": model_out[1], "logit_scale": model_out[2] } def unwrap_model(model): if hasattr(model, 'module'): return model.module else: return model def backward(total_loss, scaler): if scaler is not None: scaler.scale(total_loss).backward() else: total_loss.backward() class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class GradientAccum: def __init__(self, model, loss, scaler, autocast, input_dtype, device): self.model = model self.loss = loss self.scaler = scaler self.autocast = autocast self.input_dtype = input_dtype self.device = device self.logit_scale = unwrap_model(model).logit_scale self.arch_type = unwrap_model(model).arch_type self.accum_freq = 0 self.accum_cpu_states = [] self.accum_gpu_devices_states = [] self.accum_images = [] self.accum_texts = [] self.accum_image_features = [] self.accum_text_features = [] self.rank = dist.get_rank() def clear(self): self.accum_image_features.clear() self.accum_text_features.clear() torch.cuda.empty_cache() def clear_state(self): self.accum_images.clear() self.accum_texts.clear() self.accum_cpu_states.clear() self.accum_gpu_devices_states.clear() self.accum_freq = 0 @torch.no_grad() def accum_inference(self, images, texts): images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True) texts = 
texts.to(device=self.device, non_blocking=True)

        # First, cache the features without any gradient tracking.
        with self.autocast():
            # collect rand states
            self.accum_cpu_states.append(torch.get_rng_state())
            self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))
            model_out = self.model(images, texts)
            self.accum_image_features.append(model_out["image_features"].detach().clone())
            self.accum_text_features.append(model_out["text_features"].detach().clone())

        if self.arch_type == "lit":  # lit
            accum_image = model_out["image_trunk_features"].detach().clone()
        else:
            accum_image = images.detach().clone()
        accum_text = texts.detach().clone()
        # offloading
        accum_image = accum_image.cpu()
        accum_text = accum_text.cpu()
        self.accum_images.append(accum_image)
        self.accum_texts.append(accum_text)

        self.accum_freq += 1

    def accum_forward_backward(self):
        accum_losses = {"loss": 0.0}
        for j in range(self.accum_freq):
            # move the offloaded inputs back onto the compute device
            images = self.accum_images[j].to(device=self.device, non_blocking=True)
            texts = self.accum_texts[j].to(device=self.device, non_blocking=True)
            # refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235
            # DDP syncs gradients across GPUs on every backward, which is only needed for the last batch.
            sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext
            with torch.random.fork_rng(devices=(self.device,)), sync_context():
                # restore random states so the recomputed forward matches the cached one
                torch.set_rng_state(self.accum_cpu_states[j])
                torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j])
                with self.autocast():
                    if self.arch_type == "lit":
                        model_out = self.model(images, texts, project_only=True)
                    else:
                        model_out = self.model(images, texts)
                    inputs_no_accum = {}
                    inputs_no_accum["logit_scale"] = logit_scale = model_out.pop("logit_scale")
                    if "logit_bias" in model_out:
                        inputs_no_accum["logit_bias"] = model_out.pop("logit_bias")
                    inputs = {}
                    inputs["image_features"] = torch.cat(self.accum_image_features[:j] + [model_out["image_features"]] + self.accum_image_features[j + 1:])
                    inputs["text_features"] = torch.cat(self.accum_text_features[:j] + [model_out["text_features"]] + self.accum_text_features[j + 1:])
                    losses = self.loss(**inputs, **inputs_no_accum)
                    show_loss = losses.pop("show_loss")
                    total_loss = sum(losses.values())
                    losses["loss"] = show_loss
                    del inputs
                    del inputs_no_accum
                backward(total_loss, self.scaler)
            accum_losses["loss"] += losses["loss"]
        accum_losses["loss"] /= self.accum_freq
        self.clear()
        self.clear_state()
        return accum_losses


class GradientCache:
    def __init__(self, model, loss, scaler, autocast, input_dtype, device):
        self.model = model
        self.loss = loss
        self.scaler = scaler
        self.autocast = autocast
        self.input_dtype = input_dtype
        self.device = device

        self.logit_scale = unwrap_model(model).logit_scale
        self.arch_type = unwrap_model(model).arch_type

        self.accum_freq = 0
        self.accum_cpu_states = []
        self.accum_gpu_devices_states = []
        self.accum_images = []
        self.accum_texts = []
        self.accum_image_features = []
        self.accum_text_features = []

        self.rank = dist.get_rank()

    def clear(self):
        self.accum_image_features.clear()
        self.accum_text_features.clear()
        torch.cuda.empty_cache()

    def clear_state(self):
        self.accum_images.clear()
        self.accum_texts.clear()
        self.accum_cpu_states.clear()
        self.accum_gpu_devices_states.clear()
        self.accum_freq = 0

    def forward_backward(self, images, texts):
        images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)
        texts = texts.to(device=self.device, non_blocking=True)

        with self.autocast():
            model_out = self.model(image=images, text=texts)
            model_out.pop("image_trunk_features", None)
            losses = self.loss(**model_out)
            show_loss = losses.pop("show_loss")
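            # Loss-dict convention used throughout (illustrative, hypothetical key names):
            # self.loss returns a dict whose values are summed into the optimized
            # objective, except "show_loss", which is popped and reported as the
            # human-readable "loss" metric, e.g.:
            #   losses = {"contrastive_loss": t1, "show_loss": t2}
            #   show_loss = losses.pop("show_loss"); total = sum(losses.values())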
total_loss = sum(losses.values())
            losses["loss"] = show_loss

        backward(total_loss, self.scaler)
        return losses

    @torch.no_grad()
    def accum_inference(self, images, texts):
        images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)
        texts = texts.to(device=self.device, non_blocking=True)

        # First, cache the features without any gradient tracking.
        with self.autocast():
            # collect rand states
            self.accum_cpu_states.append(torch.get_rng_state())
            self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))
            model_out = self.model(image=images, text=texts)
            self.accum_image_features.append(model_out["image_features"])
            self.accum_text_features.append(model_out["text_features"])

        # Speed analysis of detach().clone(): https://stackoverflow.com/questions/55266154/pytorch-preferred-way-to-copy-a-tensor
        if self.arch_type == "lit":  # lit
            accum_image = model_out["image_trunk_features"].detach().clone()
        else:
            accum_image = images.detach().clone()
        accum_text = texts.detach().clone()
        # offloading
        # accum_image = accum_image.cpu()
        # accum_text = accum_text.cpu()
        self.accum_images.append(accum_image)
        self.accum_texts.append(accum_text)

        self.accum_freq += 1

    def accum_forward_backward(self):
        # Run the contrastive loss once over the full cached feature bank, taking
        # gradients w.r.t. the features themselves.
        accum_qs = [x.requires_grad_() for x in self.accum_image_features]
        qs = torch.cat(accum_qs, dim=0)
        accum_ks = [x.requires_grad_() for x in self.accum_text_features]
        ks = torch.cat(accum_ks, dim=0)
        ls = self.logit_scale.exp().detach().clone().requires_grad_()
        losses = self.loss(image_features=qs, text_features=ks, logit_scale=ls)
        show_loss = losses.pop("show_loss")
        total_loss = sum(losses.values())
        losses["loss"] = show_loss
        backward(total_loss, self.scaler)

        accum_q_grads = [q.grad for q in accum_qs]
        accum_k_grads = [k.grad for k in accum_ks]
        l_grad = ls.grad
        del accum_qs, accum_ks
        del qs, ks, ls
        # Free the temporary memory from loss calculation / inference.
        self.clear()

        for j in range(self.accum_freq):
            images = self.accum_images[j]
            texts = self.accum_texts[j]
            # refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235
            # DDP syncs gradients across GPUs on every backward, which is only needed for the last batch.
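            # How the replay below recovers exact gradients (illustrative sketch): each
            # chunk is re-run under the saved RNG state so dropout etc. reproduce the
            # cached features, and instead of the real loss we backprop the surrogate
            # dot(features, cached_feature_grads), whose parameter gradient equals the
            # chain-rule gradient of the full-batch loss. Tiny self-contained check:
            #   x = torch.randn(3, requires_grad=True)
            #   f = x * 2.0                                                        # stand-in encoder
            #   g = torch.autograd.grad((f ** 2).sum(), f, retain_graph=True)[0]   # dL/d(features)
            #   torch.dot(f.flatten(), g.detach().flatten()).backward()
            #   assert torch.allclose(x.grad, 8 * x)   # == d((2x)**2).sum() / dx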
sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext with torch.random.fork_rng(devices=(self.device, )), sync_context(): # setting random states torch.set_rng_state(self.accum_cpu_states[j]) torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j]) with self.autocast(): if self.arch_type == "lit": model_out = self.model(images, texts, project_only=True) else: model_out = self.model(images, texts) q = model_out["image_features"] k = model_out["text_features"] l = model_out["logit_scale"] _loss = torch.dot(q.flatten(), accum_q_grads[j].flatten()) + \ torch.dot(k.flatten(), accum_k_grads[j].flatten()) + \ l * l_grad / self.accum_freq _loss.backward() self.clear_state() return losses def train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None): device = torch.device(args.device) autocast = get_autocast(args.precision) input_dtype = get_input_dtype(args.precision) model.train() data['train'].set_epoch(epoch) # set epoch in process safe manner via sampler or shared_epoch dataloader = data['train'].dataloader num_batches_per_epoch = dataloader.num_batches // args.accum_freq sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) runner = GradientCache(model, loss, scaler, autocast, input_dtype, device) rest_iters = num_batches_per_epoch * (args.epochs - epoch) losses_m = {} global_batch_time_m = AverageMeter() batch_time_m = AverageMeter() data_time_m = AverageMeter() end = time.time() for i, batch in enumerate(dataloader): i_accum = i // args.accum_freq step = num_batches_per_epoch * epoch + i_accum if not args.skip_scheduler: scheduler(step) images, texts = batch data_time_m.update(time.time() - end) optimizer.zero_grad() if args.accum_freq == 1: losses = runner.forward_backward(images, texts) else: runner.accum_inference(images, texts) # If (i + 1) % accum_freq is not zero, move on to the next batch. if ((i + 1) % args.accum_freq) > 0: # FIXME this makes data time logging unreliable when accumulating continue # Now, ready to take gradients for the last accum_freq batches. # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives. # Call backwards each time, but only step optimizer at the end. losses = runner.accum_forward_backward() if scaler is not None: if args.horovod: optimizer.synchronize() scaler.unscale_(optimizer) if args.grad_clip_norm is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) with optimizer.skip_synchronize(): scaler.step(optimizer) else: if args.grad_clip_norm is not None: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) scaler.step(optimizer) scaler.update() else: if args.grad_clip_norm is not None: torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0) optimizer.step() # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
with torch.no_grad(): unwrap_model(model).logit_scale.clamp_(0, math.log(100)) global_batch_time_m.update(time.time() - end) batch_time_m.update(time.time() - end) end = time.time() batch_count = i_accum + 1 if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch): batch_size = len(images) num_samples = batch_count * batch_size * args.accum_freq * args.world_size samples_per_epoch = dataloader.num_samples percent_complete = batch_count / num_batches_per_epoch # NOTE loss is coarsely sampled, just master node and per log update for key, val in losses.items(): if key not in losses_m: losses_m[key] = AverageMeter() losses_m[key].update(val.item(), batch_size) logit_scale_scalar = unwrap_model(model).logit_scale.exp().item() loss_log = " ".join( [ f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})" for loss_name, loss_m in losses_m.items() ] ) samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val grad_norm = cal_grad_norm(model.module) running_time = seconds_to_hms(time.time() - start_timestamp) rest_iters = rest_iters - 1 whole_time = seconds_to_hms(time.time() - start_timestamp + rest_iters * global_batch_time_m.avg) logging.info( f"{running_time}<{whole_time} " f"Epoch: {epoch + percent_complete:.2f} " f"Data (t): {data_time_m.avg:.3f} " f"Batch (t): {batch_time_m.avg:.3f} " f"LR: {optimizer.param_groups[0]['lr']:5f} " f"Grad Norm: {grad_norm:.3f} " f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log + " " f"Memory: {get_memory():.2f}GB " ) # Save train loss / etc. Using non avg meter values as loggers have their own smoothing log_data = { "data_time": data_time_m.val, "batch_time": batch_time_m.val, "samples_per_second": samples_per_second, "samples_per_second_per_gpu": samples_per_second_per_gpu, "scale": logit_scale_scalar, "grad_norm": grad_norm, "lr": optimizer.param_groups[0]["lr"] } log_data.update({name:val.val for name,val in losses_m.items()}) log_data = {"train/" + name: val for name, val in log_data.items()} if tb_writer is not None: for name, val in log_data.items(): tb_writer.add_scalar(name, val, step) if args.wandb: assert wandb is not None, 'Please install wandb.' 
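            # On the logit_scale clamp earlier in this step (illustrative arithmetic):
            # logit_scale is the log of the inverse softmax temperature, so clamping it
            # to [0, ln(100) ~= 4.6052] bounds exp(logit_scale) to [1, 100], i.e. the
            # effective temperature never drops below 0.01, matching the CLIP paper:
            #   torch.tensor(5.3).clamp_(0, math.log(100))   # -> tensor(4.6052)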
log_data['step'] = step # for backwards compatibility wandb.log(log_data, step=step) # resetting batch / data time meters per log window batch_time_m.reset() data_time_m.reset() # end for def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): metrics = {} if not is_master(args): return metrics device = torch.device(args.device) model.eval() zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer) metrics.update(zero_shot_metrics) autocast = get_autocast(args.precision) input_dtype = get_input_dtype(args.precision) if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)): dataloader = data['val'].dataloader num_samples = 0 samples_per_val = dataloader.num_samples # FIXME this does not scale past small eval datasets # all_image_features @ all_text_features will blow up memory and compute very quickly cumulative_loss = 0.0 cumulative_gen_loss = 0.0 all_image_features, all_text_features = [], [] with torch.inference_mode(): for i, batch in enumerate(dataloader): images, texts = batch images = images.to(device=device, dtype=input_dtype, non_blocking=True) texts = texts.to(device=device, non_blocking=True) with autocast(): model_out = model(images, texts) image_features = model_out["image_features"] text_features = model_out["text_features"] logit_scale = model_out["logit_scale"] # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly # however, system RAM is easily exceeded and compute time becomes problematic all_image_features.append(image_features.cpu()) all_text_features.append(text_features.cpu()) logit_scale = logit_scale.mean() logits_per_image = logit_scale * image_features @ text_features.t() logits_per_text = logits_per_image.t() batch_size = images.shape[0] labels = torch.arange(batch_size, device=device).long() total_loss = ( F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels) ) / 2 gen_loss = maybe_compute_generative_loss(model_out) cumulative_loss += total_loss * batch_size num_samples += batch_size if is_master(args) and (i % 100) == 0: logging.info( f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t" f"Clip Loss: {cumulative_loss / num_samples:.6f}\t") if gen_loss is not None: cumulative_gen_loss += gen_loss * batch_size logging.info( f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t") val_metrics = get_clip_metrics( image_features=torch.cat(all_image_features), text_features=torch.cat(all_text_features), logit_scale=logit_scale.cpu(), ) loss = cumulative_loss / num_samples metrics.update( {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples} ) if gen_loss is not None: gen_loss = cumulative_gen_loss / num_samples metrics.update({"val_generative_loss": gen_loss.item()}) if not metrics: return metrics logging.info( f"Eval Epoch: {epoch} " + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) ) log_data = {"val/" + name: val for name, val in metrics.items()} if args.save_logs: if tb_writer is not None: for name, val in log_data.items(): tb_writer.add_scalar(name, val, epoch) with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: f.write(json.dumps(metrics)) f.write("\n") if args.wandb: assert wandb is not None, 'Please install wandb.' 
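        # The validation objective computed above is the standard symmetric InfoNCE;
        # in isolation (illustrative):
        #   logits = logit_scale * image_features @ text_features.t()   # (N, N)
        #   labels = torch.arange(logits.shape[0], device=logits.device)
        #   loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2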
if 'train' in data: dataloader = data['train'].dataloader num_batches_per_epoch = dataloader.num_batches // args.accum_freq step = num_batches_per_epoch * epoch else: step = None log_data['epoch'] = epoch wandb.log(log_data, step=step) return metrics def zero_shot_run(model, classifier, dataloader, args): autocast = get_autocast(args.precision) input_dtype = get_input_dtype(args.precision) with torch.inference_mode(): top1, top5, n = 0., 0., 0. for images, target in tqdm(dataloader, unit_scale=args.batch_size): images = images.to(device=args.device, dtype=input_dtype) target = target.to(args.device) with autocast(): # predict output = model(image=images) image_features = output['image_features'] if isinstance(output, dict) else output[0] logits = 100. * image_features @ classifier # measure accuracy acc1, acc5 = accuracy(logits, target, topk=(1, 5)) top1 += acc1 top5 += acc5 n += images.size(0) top1 = (top1 / n) top5 = (top5 / n) return top1, top5 def zero_shot_eval(model, data, epoch, args, tokenizer=None): if 'imagenet-val' not in data and 'imagenet-v2' not in data: return {} if args.zeroshot_frequency == 0: return {} if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: return {} if args.distributed and not args.horovod: model = model.module logging.info('Starting zero-shot imagenet.') if tokenizer is None: tokenizer = get_tokenizer(args.model) logging.info('Building zero-shot classifier') autocast = get_autocast(args.precision) with autocast(): classifier = build_zero_shot_classifier( model, tokenizer=tokenizer, classnames=IMAGENET_CLASSNAMES, templates=OPENAI_IMAGENET_TEMPLATES, num_classes_per_batch=10, device=args.device, use_tqdm=True, ) logging.info('Using classifier') results = {} if 'imagenet-val' in data: top1, top5 = zero_shot_run(model, classifier, data['imagenet-val'].dataloader, args) results['imagenet-zeroshot-val-top1'] = top1 results['imagenet-zeroshot-val-top5'] = top5 if 'imagenet-v2' in data: top1, top5 = zero_shot_run(model, classifier, data['imagenet-v2'].dataloader, args) results['imagenetv2-zeroshot-val-top1'] = top1 results['imagenetv2-zeroshot-val-top5'] = top5 logging.info('Finished zero-shot imagenet.') return results ================================================ FILE: inf_clip/train/main.py ================================================ import glob import logging import os import re import subprocess import sys import random import time from functools import partial import numpy as np import torch from torch import optim from torch.cuda.amp import GradScaler try: import wandb except ImportError: wandb = None try: import torch.utils.tensorboard as tensorboard except ImportError: tensorboard = None try: import horovod.torch as hvd except ImportError: hvd = None from inf_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss from inf_clip.train.data import get_data from inf_clip.train.params import parse_args from inf_clip.train.optims import ScalingViTAdafactor, Lion from inf_clip.train.engine import cosine_lr, const_lr, const_lr_cooldown, train_one_epoch, evaluate from inf_clip.train.utils import (setup_logging, pt_load, check_exists, start_sync_process, remote_sync, is_master, init_distributed_device, broadcast_object) LATEST_CHECKPOINT_NAME = "epoch_latest.pt" def random_seed(seed=42, rank=0): random.seed(seed + rank) np.random.seed(seed + rank) torch.manual_seed(seed + rank) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def natural_key(string_): """See 
http://www.codinghorror.com/blog/archives/001018.html""" return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] def copy_codebase(args): from shutil import copytree, ignore_patterns new_code_path = os.path.join(args.log_dir, args.name, "code") if os.path.exists(new_code_path): print( f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." ) return -1 print(f"Copying codebase to {new_code_path}") current_code_path = os.path.realpath(__file__) for _ in range(3): current_code_path = os.path.dirname(current_code_path) copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb')) print("Done copying code.") return 1 def prepare_logging(args): # get the name of the experiments if args.name is None: from datetime import datetime # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? model_name_safe = args.model.replace('/', '-') date_str = datetime.now().strftime("%Y_%m_%d-%H_%M_%S") if args.distributed: # sync date_str from master to all ranks date_str = broadcast_object(args, date_str) args.name = '-'.join([ date_str, f"model_{model_name_safe}", f"lr_{args.lr}", f"b_{args.batch_size}", f"j_{args.workers}", f"p_{args.precision}", ]) resume_latest = args.resume == 'latest' log_base_path = os.path.join(args.log_dir, args.name) args.log_path = None if is_master(args, local=args.log_local): os.makedirs(log_base_path, exist_ok=True) log_filename = f'out-{args.rank}' if args.log_local else 'out.log' args.log_path = os.path.join(log_base_path, log_filename) # if os.path.exists(args.log_path) and not resume_latest: # print( # "Error. Experiment already exists. Use --name {} to specify a new experiment." # ) # return -1 # Setup text logger args.log_level = logging.DEBUG if args.debug else logging.INFO setup_logging(args.log_path, args.log_level) # Setup wandb, tensorboard, checkpoint logging args.wandb = 'wandb' in args.report_to or 'all' in args.report_to args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to args.checkpoint_path = os.path.join(log_base_path, "checkpoints") if is_master(args): args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else '' for dirname in [args.tensorboard_path, args.checkpoint_path]: if dirname: os.makedirs(dirname, exist_ok=True) else: args.tensorboard_path = '' if args.copy_codebase: copy_codebase(args) return log_base_path, resume_latest def get_latest_checkpoint(path: str, remote : bool): # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders if remote: result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) print(result) if result.returncode == 1: return None checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] else: checkpoints = glob.glob(path + '**/*.pt', recursive=True) if checkpoints: checkpoints = sorted(checkpoints, key=natural_key) return checkpoints[-1] return None def prepare_resuming(args): resume_from = None checkpoint_path = args.checkpoint_path # If using remote_sync, need to check the remote instead of the local checkpoints folder. if args.remote_sync is not None: checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints") if args.save_most_recent: print('Error. Cannot use save-most-recent with remote_sync and resume latest.') return -1 if args.remote_sync_protocol != 's3': print('Error. 
Sync protocol not supported when using resume latest.') return -1 if is_master(args): # Checking for existing checkpoint via master rank only. It is possible for # different rank processes to see different files if a shared file-system is under # stress, however it's very difficult to fully work around such situations. if args.save_most_recent: # if --save-most-recent flag is set, look for latest at a fixed filename resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) if not os.path.exists(resume_from): # If no latest checkpoint has been saved yet, don't try to resume resume_from = None else: # otherwise, list checkpoint dir contents and pick the newest checkpoint resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) if resume_from: logging.info(f'Found latest resume checkpoint at {resume_from}.') else: logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') if args.distributed: # sync found checkpoint path to all ranks resume_from = broadcast_object(args, resume_from) args.resume = resume_from def prepare_remote_sync(args): # start the sync proces if remote-sync is not None remote_sync_process = None if is_master(args) and args.remote_sync is not None: # first make sure it works result = remote_sync( os.path.join(args.log_dir, args.name), os.path.join(args.remote_sync, args.name), args.remote_sync_protocol ) if result: logging.info('remote sync successful.') else: logging.info('Error: remote sync failed. Exiting.') return -1 # if all looks good, start a process to do this every args.remote_sync_frequency seconds remote_sync_process = start_sync_process( args.remote_sync_frequency, os.path.join(args.log_dir, args.name), os.path.join(args.remote_sync, args.name), args.remote_sync_protocol ) remote_sync_process.start() return remote_sync_process def prepare_model(args, device): dist_model = None args.distill = args.distill_model is not None and args.distill_pretrained is not None if args.distill: #FIXME: support distillation with grad accum. assert args.accum_freq == 1 #FIXME: support distillation with coca. assert 'coca' not in args.model.lower() if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1: # arg is nargs, single (square) image size list -> int args.force_image_size = args.force_image_size[0] random_seed(args.seed, 0) model_kwargs = {} if args.siglip: model_kwargs['init_logit_scale'] = np.log(10) # different from CLIP model_kwargs['init_logit_bias'] = -10 model, preprocess_train, preprocess_val = create_model_and_transforms( args.model, args.pretrained, precision=args.precision, device=device, jit=args.torchscript, force_quick_gelu=args.force_quick_gelu, force_custom_text=args.force_custom_text, force_patch_dropout=args.force_patch_dropout, force_image_size=args.force_image_size, image_mean=args.image_mean, image_std=args.image_std, image_interpolation=args.image_interpolation, image_resize_mode=args.image_resize_mode, # only effective for inference aug_cfg=args.aug_cfg, pretrained_image=args.pretrained_image, output_dict=True, **model_kwargs, ) if args.distill: # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms. 
dist_model, _, _ = create_model_and_transforms( args.distill_model, args.distill_pretrained, device=device, precision=args.precision, output_dict=True, ) if args.use_bnb_linear is not None: print('=> using a layer from bitsandbytes.\n' ' this is an experimental feature which requires two extra pip installs\n' ' pip install bitsandbytes triton' ' please make sure to use triton 2.0.0') import bitsandbytes as bnb from open_clip.utils import replace_linear print(f'=> replacing linear layers with {args.use_bnb_linear}') linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear) replace_linear(model, linear_replacement_cls) model = model.to(device) random_seed(args.seed, args.rank) if args.trace: model = trace_model(model, batch_size=args.batch_size, device=device) if args.lock_image: # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 model.lock_image_tower( unlocked_groups=args.lock_image_unlocked_groups, freeze_bn_stats=args.lock_image_freeze_bn_stats) if args.lock_text: model.lock_text_tower( unlocked_layers=args.lock_text_unlocked_layers, freeze_layer_norm=args.lock_text_freeze_layer_norm) if args.grad_checkpointing: model.set_grad_checkpointing() if args.distributed and not args.horovod: if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) ddp_args = {} if args.ddp_static_graph: # this doesn't exist in older PyTorch, arg only added if enabled ddp_args['static_graph'] = True model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) if args.distill: dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) if is_master(args): logging.info("Model:") logging.info(f"{str(model)}") logging.info("Params:") params_file = os.path.join(args.log_dir, args.name, "params.txt") with open(params_file, "w") as f: for name in sorted(vars(args)): val = getattr(args, name) logging.info(f" {name}: {val}") f.write(f"{name}: {val}\n") tokenizer = get_tokenizer(args.model) return tokenizer, model, dist_model, preprocess_train, preprocess_val def prepare_optimizer_scaler(args, model): assert not args.trace, 'Cannot train with traced model' exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n include = lambda n, p: not exclude(n, p) named_parameters = list(model.named_parameters()) named_parameters = [(n, p) for n, p in named_parameters if p.requires_grad] gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] if args.optimizer == "adam": optimizer = optim.AdamW( [ {"params": gain_or_bias_params, "weight_decay": 0.}, {"params": rest_params, "weight_decay": args.wd}, ], lr=args.lr, betas=(args.beta1, args.beta2), eps=args.eps, ) elif args.optimizer == "adafactor": optimizer = ScalingViTAdafactor( [ {"params": gain_or_bias_params, "weight_decay": 0.}, {"params": rest_params, "weight_decay": args.wd}, ], lr=args.lr, beta1=args.beta1, beta2=args.beta2, ) elif args.optimizer == "lion": optimizer = Lion( [ {"params": gain_or_bias_params, "weight_decay": 0.}, {"params": rest_params, "weight_decay": args.wd}, ], lr=args.lr, betas=(args.beta1, args.beta2), ) # elif args.optim == "lamb": if args.horovod: optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters()) hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) scaler = GradScaler() 
if args.precision == "amp" else None
    return optimizer, scaler


def prepare_scheduler(args, optimizer, num_batches):
    scheduler = None
    total_steps = (num_batches // args.accum_freq) * args.epochs
    if args.lr_scheduler == "cosine":
        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
    elif args.lr_scheduler == "const":
        scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
    elif args.lr_scheduler == "const-cooldown":
        assert args.epochs_cooldown is not None, \
            "Please specify the number of cooldown epochs for this lr schedule."
        cooldown_steps = (num_batches // args.accum_freq) * args.epochs_cooldown
        scheduler = const_lr_cooldown(
            optimizer, args.lr, args.warmup, total_steps,
            cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
    else:
        logging.error(
            f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
        exit(1)
    return scheduler
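# Illustrative only (not called by the trainer): what prepare_scheduler's cosine
# option produces for a toy run with warmup_length=3 and 10 total steps.
def _cosine_schedule_demo():
    opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    adjust = cosine_lr(opt, base_lr=1.0, warmup_length=3, steps=10)
    # linear warmup 1/3, 2/3, 1.0, then half-cosine decay toward 0
    return [round(adjust(step), 3) for step in range(10)]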
def main(args):
    args = parse_args(args)

    if torch.cuda.is_available():
        # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
        # This was a default in pytorch until 1.12
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # fully initialize distributed device environment
    device = init_distributed_device(args)

    log_base_path, resume_latest = prepare_logging(args)
    if resume_latest:
        prepare_resuming(args)
    remote_sync_process = prepare_remote_sync(args)

    if args.horovod:
        logging.info(
            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    elif args.distributed:
        logging.info(
            f'Running in distributed mode with multiple processes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    else:
        logging.info(f'Running with a single process. Device {args.device}.')

    tokenizer, model, dist_model, preprocess_train, preprocess_val = prepare_model(args, device)

    # create optimizer and scaler (left as None for evaluation-only runs)
    optimizer, scaler = None, None
    if args.train_data or args.dataset_type == "synthetic":
        optimizer, scaler = prepare_optimizer_scaler(args, model)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume is not None:
        checkpoint = pt_load(args.resume, map_location='cpu')
        if 'epoch' in checkpoint:
            # resuming a train checkpoint w/ epoch and optimizer state
            start_epoch = checkpoint["epoch"]
            sd = checkpoint["state_dict"]
            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items()}
            model.load_state_dict(sd)
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])
            if scaler is not None and 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
        else:
            # loading a bare (model only) checkpoint for fine-tune or evaluation
            model.load_state_dict(checkpoint)
            logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")

    # initialize datasets
    data = get_data(
        args,
        (preprocess_train, preprocess_val),
        epoch=start_epoch,
        tokenizer=tokenizer,
    )
    assert len(data), 'At least one train or eval dataset must be specified.'

    # create scheduler if train
    scheduler = None
    if 'train' in data and optimizer is not None:
        scheduler = prepare_scheduler(args, optimizer, data["train"].dataloader.num_batches)

    # determine if this worker should save logs and checkpoints; only do so if it is rank == 0
    args.save_logs = args.log_dir and args.log_dir.lower() != 'none' and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        assert wandb is not None, 'Please install wandb.'
        logging.debug('Starting wandb.')
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        wandb.init(
            project=args.wandb_project_name,
            name=args.name,
            id=args.name,
            notes=args.wandb_notes,
            tags=[],
            resume='auto' if args.resume == "latest" else None,
            config=vars(args),
        )
        if args.debug:
            wandb.watch(model, log='all')
        # params.txt is written by prepare_model
        wandb.save(os.path.join(args.log_dir, args.name, "params.txt"))
        logging.debug('Finished loading wandb.')

    # Pytorch 2.0 adds '_orig_mod.' prefix to keys of state_dict() of compiled models.
    # For compatibility, we save state_dict() of the original model, which shares the
    # weights without the prefix.
    original_model = model
    if args.torchcompile:
        logging.info('Compiling model...')
        model = torch.compile(original_model)

    if 'train' not in data:
        # If using int8, convert to inference mode.
        if args.use_bnb_linear is not None:
            from open_clip.utils import convert_int8_model_to_inference_mode
            convert_int8_model_to_inference_mode(model)
        # Evaluate.
        evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer)
        return

    loss = create_loss(args)

    start_timestamp = time.time()
    for epoch in range(start_epoch, args.epochs):
        if is_master(args):
            logging.info(f'Start epoch {epoch}')

        train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)
        completed_epoch = epoch + 1

        if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
            evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer)

        # Saving checkpoints.
        if args.save_logs:
            checkpoint_dict = {
                "epoch": completed_epoch,
                "name": args.name,
                "state_dict": original_model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if scaler is not None:
                checkpoint_dict["scaler"] = scaler.state_dict()

            if completed_epoch == args.epochs or (
                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
            ):
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
                )
                if args.delete_previous_checkpoint:
                    previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
                    if os.path.exists(previous_checkpoint):
                        os.remove(previous_checkpoint)
            if args.save_most_recent:
                # try not to corrupt the latest checkpoint if save fails
                tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
                latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
                torch.save(checkpoint_dict, tmp_save_path)
                os.replace(tmp_save_path, latest_save_path)

    if args.wandb and is_master(args):
        wandb.finish()

    # run a final sync.
if remote_sync_process is not None: logging.info('Final remote sync.') remote_sync_process.terminate() result = remote_sync( os.path.join(args.log_dir, args.name), os.path.join(args.remote_sync, args.name), args.remote_sync_protocol ) if result: logging.info('Final remote sync successful.') else: logging.info('Final remote sync failed.') if __name__ == "__main__": main(sys.argv[1:]) ================================================ FILE: inf_clip/train/optims.py ================================================ import math import torch from torch import nn from torch.optim import Optimizer class ScalingViTAdafactor(Optimizer): """ Modified version of Adafactor in Transformers https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/optimization.py#L672, which refers to Paper: *Scaling Vision Transformers* https://arxiv.org/pdf/2106.04560 1. Re-introducing the first momentum in half-precision. 2. Disable scaling of learning rate relative to weight norms, a feature that is part of Adafactor. 3. Clipping the second momentum at 0.999 """ def __init__( self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0, decay_rate=-0.8, beta1=0.9, beta2=0.999, weight_decay=0.0, scale_parameter=False, relative_step=False, warmup_init=False, ): if lr is not None and relative_step: raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") if warmup_init and not relative_step: raise ValueError("`warmup_init=True` requires `relative_step=True`") defaults = { "lr": lr, "eps": eps, "clip_threshold": clip_threshold, "decay_rate": decay_rate, "beta1": beta1, "beta2": beta2, "weight_decay": weight_decay, "scale_parameter": scale_parameter, "relative_step": relative_step, "warmup_init": warmup_init, } super().__init__(params, defaults) @staticmethod def _get_lr(param_group, param_state): rel_step_sz = param_group["lr"] if param_group["relative_step"]: min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) param_scale = 1.0 if param_group["scale_parameter"]: param_scale = max(param_group["eps"][1], param_state["RMS"]) return param_scale * rel_step_sz @staticmethod def _get_options(param_group, param_shape): factored = len(param_shape) >= 2 use_first_moment = param_group["beta1"] is not None return factored, use_first_moment @staticmethod def _rms(tensor): return tensor.norm(2) / (tensor.numel() ** 0.5) @staticmethod def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): # copy from fairseq's adafactor implementation: # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() return torch.mul(r_factor, c_factor) @torch.no_grad() def step(self, closure=None): """ Performs a single optimization step Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. 
""" loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group["params"]: if p.grad is None: continue grad = p.grad # NOTE: gradient keep in float32 # if grad.dtype in {torch.float16, torch.bfloat16}: # grad = grad.float() if grad.is_sparse: raise RuntimeError("Adafactor does not support sparse gradients.") state = self.state[p] grad_shape = grad.shape factored, use_first_moment = self._get_options(group, grad_shape) # State Initialization if len(state) == 0: state["step"] = 0 if use_first_moment: # NOTE: using bfloat16 for first momentum # Exponential moving average of gradient values state["exp_avg"] = torch.zeros_like(grad).bfloat16() if factored: state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) else: state["exp_avg_sq"] = torch.zeros_like(grad) state["RMS"] = 0 else: if use_first_moment: state["exp_avg"] = state["exp_avg"].to(grad) if factored: state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) else: state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) p_data_fp32 = p # NOTE: keep in float32 # if p.dtype in {torch.float16, torch.bfloat16}: # p_data_fp32 = p_data_fp32.float() state["step"] += 1 state["RMS"] = self._rms(p_data_fp32) lr = self._get_lr(group, state) beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) beta2t = min(beta2t, group["beta2"]) update = (grad**2) + group["eps"][0] if factored: exp_avg_sq_row = state["exp_avg_sq_row"] exp_avg_sq_col = state["exp_avg_sq_col"] exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) # Approximation of exponential moving average of square of gradient update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) update.mul_(grad) else: exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) update = exp_avg_sq.rsqrt().mul_(grad) update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) update.mul_(lr) if use_first_moment: exp_avg = state["exp_avg"] exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) update = exp_avg if group["weight_decay"] != 0: p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) p_data_fp32.add_(-update) # if p.dtype in {torch.float16, torch.bfloat16}: # p.copy_(p_data_fp32) return loss class Lion(Optimizer): """ Modified version of Lion in https://github.com/google/automl/blob/master/lion/lion_pytorch.py, which refers to Paper: *Symbolic Discovery of Optimization Algorithms* https://arxiv.org/pdf/2302.06675 """ def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0): """Initialize the hyperparameters. 
Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-4) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99)) weight_decay (float, optional): weight decay coefficient (default: 0) """ if not 0.0 <= lr: raise ValueError('Invalid learning rate: {}'.format(lr)) if not 0.0 <= betas[0] < 1.0: raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) super().__init__(params, defaults) @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue # Perform stepweight decay p.data.mul_(1 - group['lr'] * group['weight_decay']) grad = p.grad state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p) exp_avg = state['exp_avg'] beta1, beta2 = group['betas'] # Weight update update = exp_avg * beta1 + grad * (1 - beta1) p.add_(update.sign_(), alpha=-group['lr']) # Decay the momentum running average coefficient exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) return loss ================================================ FILE: inf_clip/train/params.py ================================================ import os import argparse import ast import json from .utils import world_info_from_env def get_default_params(model_name): # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) model_name = model_name.lower() if "vit" in model_name: return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} else: return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} class ParseKwargs(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): kw = {} for value in values: key, value = value.split('=') try: kw[key] = ast.literal_eval(value) except ValueError: kw[key] = str(value) # fallback to string (avoid need to escape on command line) setattr(namespace, self.dest, kw) def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument( "--train-data", type=str, default=None, help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.", ) parser.add_argument( "--train-data-upsampling-factors", type=str, default=None, help=( "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. " "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) " "By default, datapoints are sampled uniformly regardless of the dataset sizes." ) ) parser.add_argument( "--val-data", type=str, default=None, help="Path to file(s) with validation data", ) parser.add_argument( "--train-num-samples", type=int, default=None, help="Number of samples in dataset. Required for webdataset if not available in info file.", ) parser.add_argument( "--val-num-samples", type=int, default=None, help="Number of samples in dataset. 
Useful for webdataset if not available in info file.", ) parser.add_argument( "--dataset-type", choices=["webdataset", "csv", "synthetic", "auto"], default="auto", help="Which type of dataset to process." ) parser.add_argument( "--dataset-resampled", default=False, action="store_true", help="Whether to use sampling with replacement for webdataset shard selection." ) parser.add_argument( "--csv-separator", type=str, default="\t", help="For csv-like datasets, which separator to use." ) parser.add_argument( "--csv-img-key", type=str, default="filepath", help="For csv-like datasets, the name of the key for the image paths." ) parser.add_argument( "--csv-caption-key", type=str, default="title", help="For csv-like datasets, the name of the key for the captions." ) parser.add_argument( "--imagenet-val", type=str, default=None, help="Path to imagenet val set for conducting zero shot evaluation.", ) parser.add_argument( "--imagenet-v2", type=str, default=None, help="Path to imagenet v2 for conducting zero shot evaluation.", ) parser.add_argument( "--log_dir", type=str, default="./logs/", help="Where to store tensorboard logs. Use None to avoid storing logs.", ) parser.add_argument( "--log-local", action="store_true", default=False, help="log files on local master, otherwise global master only.", ) parser.add_argument( "--name", type=str, default=None, help="Optional identifier for the experiment when storing logs. Otherwise use current time.", ) parser.add_argument( "--workers", type=int, default=4, help="Number of dataloader workers per GPU." ) parser.add_argument( "--batch-size", type=int, default=64, help="Batch size per GPU." ) parser.add_argument( "--epochs", type=int, default=32, help="Number of epochs to train for." ) parser.add_argument( "--epochs-cooldown", type=int, default=None, help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards." ) parser.add_argument("--optimizer", type=str, default="adam", help="Optimizer to use.") parser.add_argument("--lr", type=float, default=None, help="Learning rate.") parser.add_argument("--beta1", type=float, default=None, help="coefficient of moving average of first moment.") parser.add_argument("--beta2", type=float, default=None, help="coefficient of moving average of second moment.") parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") parser.add_argument( "--warmup", type=int, default=10000, help="Number of steps to warmup for." ) parser.add_argument( "--use-bn-sync", default=False, action="store_true", help="Whether to use batch norm sync.") parser.add_argument( "--skip-scheduler", action="store_true", default=False, help="Use this flag to skip the learning rate decay.", ) parser.add_argument( "--lr-scheduler", type=str, default='cosine', help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine", ) parser.add_argument( "--lr-cooldown-end", type=float, default=0.0, help="End learning rate for cooldown schedule. Default: 0" ) parser.add_argument( "--lr-cooldown-power", type=float, default=1.0, help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)" ) parser.add_argument( "--save-frequency", type=int, default=1, help="How often to save checkpoints." 
) parser.add_argument( "--save-most-recent", action="store_true", default=False, help="Always save the most recent model trained to epoch_latest.pt.", ) parser.add_argument( "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." ) parser.add_argument( "--val-frequency", type=int, default=1, help="How often to run evaluation with val data." ) parser.add_argument( "--resume", default=None, type=str, help="path to latest checkpoint (default: none)", ) parser.add_argument( "--precision", choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"], default="amp", help="Floating point precision." ) parser.add_argument( "--model", type=str, default="RN50", help="Name of the vision backbone to use.", ) parser.add_argument( "--pretrained", default='', type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.", ) parser.add_argument( "--pretrained-image", default=False, action='store_true', help="Load imagenet pretrained weights for image tower backbone if available.", ) parser.add_argument( "--lock-image", default=False, action='store_true', help="Lock full image tower by disabling gradients.", ) parser.add_argument( "--lock-image-unlocked-groups", type=int, default=0, help="Leave last n image tower layer groups unlocked.", ) parser.add_argument( "--lock-image-freeze-bn-stats", default=False, action='store_true', help="Freeze BatchNorm running stats in image tower for any locked layers.", ) parser.add_argument( '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override default image mean value of dataset') parser.add_argument( '--image-std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of dataset') parser.add_argument( '--image-interpolation', default=None, type=str, choices=['bicubic', 'bilinear', 'random'], help="Override default image resize interpolation" ) parser.add_argument( '--image-resize-mode', default=None, type=str, choices=['shortest', 'longest', 'squash'], help="Override default image resize (& crop) mode during inference" ) parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs) parser.add_argument( "--grad-checkpointing", default=False, action='store_true', help="Enable gradient checkpointing.", ) parser.add_argument( "--local-loss", default=False, action="store_true", help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)" ) parser.add_argument( "--gather-with-grad", default=False, action="store_true", help="enable full distributed gradient for feature gather" ) parser.add_argument( '--force-image-size', type=int, nargs='+', default=None, help='Override default image size' ) parser.add_argument( "--force-quick-gelu", default=False, action='store_true', help="Force use of QuickGELU activation for non-OpenAI transformer models.", ) parser.add_argument( "--force-patch-dropout", default=None, type=float, help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper", ) parser.add_argument( "--force-custom-text", default=False, action='store_true', help="Force use of CustomTextCLIP model (separate text-tower).", ) parser.add_argument( "--torchscript", default=False, action='store_true', help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", ) parser.add_argument( "--torchcompile", default=False, action='store_true', help="torch.compile() the model, requires pytorch 2.0 
or later.", ) parser.add_argument( "--trace", default=False, action='store_true', help="torch.jit.trace the model for inference / eval only", ) parser.add_argument( "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps." ) # arguments for distributed training parser.add_argument( "--dist-url", default="env://", type=str, help="url used to set up distributed training", ) parser.add_argument( "--dist-backend", default="nccl", type=str, help="distributed backend" ) parser.add_argument( "--report-to", default='', type=str, help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']" ) parser.add_argument( "--wandb-notes", default='', type=str, help="Notes if logging with wandb" ) parser.add_argument( "--wandb-project-name", type=str, default='open-clip', help="Name of the project if logging with wandb.", ) parser.add_argument( "--debug", default=False, action="store_true", help="If true, more information is logged." ) parser.add_argument( "--copy-codebase", default=False, action="store_true", help="If true, we copy the entire base on the log directory, and execute from there." ) parser.add_argument( "--horovod", default=False, action="store_true", help="Use horovod for distributed training." ) parser.add_argument( "--ddp-static-graph", default=False, action='store_true', help="Enable static graph optimization for DDP in PyTorch >= 1.11.", ) parser.add_argument( "--no-set-device-rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)." ) parser.add_argument( "--seed", type=int, default=0, help="Default random seed." ) parser.add_argument( "--grad-clip-norm", type=float, default=None, help="Gradient clip." ) parser.add_argument( "--lock-text", default=False, action='store_true', help="Lock full text tower by disabling gradients.", ) parser.add_argument( "--lock-text-unlocked-layers", type=int, default=0, help="Leave last n text tower layer groups unlocked.", ) parser.add_argument( "--lock-text-freeze-layer-norm", default=False, action='store_true', help="Freeze LayerNorm running stats in text tower for any locked layers.", ) parser.add_argument( "--log-every-n-steps", type=int, default=100, help="Log every n steps to tensorboard/console/wandb.", ) parser.add_argument( "--coca-caption-loss-weight", type=float, default=2.0, help="Weight assigned to caption loss in CoCa." ) parser.add_argument( "--coca-contrastive-loss-weight", type=float, default=1.0, help="Weight assigned to contrastive loss when training CoCa." ) parser.add_argument( "--remote-sync", type=str, default=None, help="Optinoally sync with a remote path specified by this arg", ) parser.add_argument( "--remote-sync-frequency", type=int, default=300, help="How frequently to sync to a remote directly if --remote-sync is not None.", ) parser.add_argument( "--remote-sync-protocol", choices=["s3", "fsspec"], default="s3", help="How to do the remote sync backup if --remote-sync is not None.", ) parser.add_argument( "--delete-previous-checkpoint", default=False, action="store_true", help="If true, delete previous checkpoint after storing a new one." ) parser.add_argument( "--distill-model", default=None, help='Which model arch to distill from, if any.' ) parser.add_argument( "--distill-pretrained", default=None, help='Which pre-trained weights to distill from, if any.' ) parser.add_argument( "--use-bnb-linear", default=None, help='Replace the network linear layers from the bitsandbytes library. 
def create_deepspeed_config(args):
    _, _, world_size = world_info_from_env()
    args.deepspeed_config = os.path.join(os.getcwd(), "scripts", "deepspeed_config.json")

    # default optimizer
    optim_settings = None
    if args.optimizer.lower() == "adamw":
        optim_settings = {
            "type": "Adam",
            "adam_w_mode": True,
            "params": {
                "bias_correction": True,
                "betas": [args.beta1, args.beta2],
                "eps": args.eps,
            }
        }
    # LAMB
    elif args.optimizer.lower() == "lamb":
        # https://arxiv.org/pdf/1904.00962.pdf
        optim_settings = {
            "type": "LAMB",
            "params": {
                "bias_correction": True,
                "betas": [args.beta1, args.beta2],
                "eps": args.eps,
                "max_coeff": 10.0,  # 0.3
                "min_coeff": 0.01,
                "eps_inside_sqrt": False,
            }
        }
    if args.optimizer.lower() == "1bitlamb":
        # not supported
        # 1bit-Lamb is not compatible with ZeRO; zero-stage should be 0
        # https://arxiv.org/abs/2104.06069
        optim_settings = {
            "type": "OneBitLamb",
            "params": {
                "bias_correction": True,
                "betas": [args.beta1, args.beta2],
                "eps": args.eps,
                "max_coeff": 10.0,  # 0.3
                "min_coeff": 0.01,
                "eps_inside_sqrt": False,
                "freeze_step": args.warmup,
                # "comm_backend_name": "nccl",
                # "coeff_beta": 0.9,
                # "factor_max": 4.0,
                # "factor_min": 0.5,
                # "factor_threshold": 0.1
            }
        }

    with open(args.deepspeed_config, mode="w") as writer:
        ds_config = {
            "train_batch_size": args.batch_size * world_size * args.accum_freq,
            "train_micro_batch_size_per_gpu": args.batch_size,
            "gradient_accumulation_steps": args.accum_freq,
            "gradient_accumulation_dtype": "fp32",
            "steps_per_print": 1000000,
            "zero_allow_untested_optimizer": True,
            "fp16": {
                "enabled": True if args.precision != "bf16" else False,
                # "auto_cast": True,
                "loss_scale": 0,
                "initial_scale_power": 0,
                "loss_scale_window": 1000,
                "hysteresis": 2,
                "min_loss_scale": 1
            },
            "bf16": {"enabled": args.precision == "bf16"},
            "amp": {"enabled": False, "opt_level": "O2"},
            "flops_profiler": {
                "enabled": True,
                "profile_step": -1,
                "module_depth": -1,
                "top_modules": 1,
                "detailed": True,
            },
            "activation_checkpointing": {
                "partition_activations": args.grad_checkpointing,
                "contiguous_memory_optimization": False,
                "profile": True
            },
            # "wallclock_breakdown": True
        }
        if optim_settings is not None:
            ds_config.update({'optimizer': optim_settings})
        if args.grad_clip_norm is not None:
            ds_config.update({'gradient_clipping': args.grad_clip_norm})
        if args.zero_stage == 1:
            ds_config.update({"zero_optimization": {"stage": 1, "reduce_bucket_size": 5e8}})
"contiguous_gradients": ('vit-b' not in args.model.lower()), # should be False if model is small, "overlap_comm": True, "reduce_scatter": True, "reduce_bucket_size": 5e8, "allgather_bucket_size": 5e8, "cpu_offload": False } } ) elif args.zero_stage == 3: ds_config.update( { "zero_optimization": { "stage": 3, "contiguous_gradients": True, "overlap_comm": True, "reduce_scatter": True, "reduce_bucket_size": 5e4, "allgather_bucket_size": 5e4, "cpu_offload": False, }, "stage3_max_live_parameters": 1e5, "stage3_max_reuse_distance": 1e5, } ) elif args.zero_stage > 3: raise NotImplementedError() writer.write(json.dumps(ds_config, indent=2)) ================================================ FILE: inf_clip/train/utils.py ================================================ import os import time import logging import subprocess import multiprocessing import torch import torch.distributed as dist import fsspec from tqdm import tqdm from contextlib import suppress try: import horovod.torch as hvd except ImportError: hvd = None def setup_logging(log_file, level, include_host=False): if include_host: import socket hostname = socket.gethostname() formatter = logging.Formatter( f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') else: formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') logging.root.setLevel(level) loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] for logger in loggers: logger.setLevel(level) stream_handler = logging.StreamHandler() stream_handler.setFormatter(formatter) logging.root.addHandler(stream_handler) if log_file: file_handler = logging.FileHandler(filename=log_file) file_handler.setFormatter(formatter) logging.root.addHandler(file_handler) def remote_sync_s3(local_dir, remote_dir): # skip epoch_latest which can change during sync. result = subprocess.run(["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE) if result.returncode != 0: logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}") return False logging.info(f"Successfully synced with S3 bucket") return True def remote_sync_fsspec(local_dir, remote_dir): # FIXME currently this is slow and not recommended. Look into speeding up. a = fsspec.get_mapper(local_dir) b = fsspec.get_mapper(remote_dir) for k in a: # skip epoch_latest which can change during sync. if 'epoch_latest.pt' in k: continue logging.info(f'Attempting to sync {k}') if k in b and len(a[k]) == len(b[k]): logging.debug(f'Skipping remote sync for {k}.') continue try: logging.info(f'Successful sync for {k}.') b[k] = a[k] except Exception as e: logging.info(f'Error during remote sync for {k}: {e}') return False return True def remote_sync(local_dir, remote_dir, protocol): logging.info('Starting remote sync.') if protocol == 's3': return remote_sync_s3(local_dir, remote_dir) elif protocol == 'fsspec': return remote_sync_fsspec(local_dir, remote_dir) else: logging.error('Remote protocol not known') return False def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol): while True: time.sleep(sync_every) remote_sync(local_dir, remote_dir, protocol) def start_sync_process(sync_every, local_dir, remote_dir, protocol): p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol)) return p # Note: we are not currently using this save function. 
# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    of = fsspec.open(file_path, "wb")
    with of as f:
        # save through the opened fsspec handle, not the raw path, so that
        # remote destinations (e.g. s3://...) work as intended
        torch.save(pt_obj, f)


def pt_load(file_path, map_location=None):
    if file_path.startswith('s3'):
        logging.info('Loading remote checkpoint, which may take a bit.')
    of = fsspec.open(file_path, "rb")
    with of as f:
        out = torch.load(f, map_location=map_location)
    return out


def check_exists(file_path):
    try:
        with fsspec.open(file_path):
            pass
    except FileNotFoundError:
        return False
    return True


def get_autocast(precision):
    if precision == 'amp':
        return lambda: torch.amp.autocast("cuda", dtype=torch.bfloat16)
    elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.amp.autocast("cuda", dtype=torch.bfloat16)
    else:
        return suppress


##################################
# Distributed training utilities #
##################################


def is_global_master(args):
    return args.rank == 0


def is_local_master(args):
    return args.local_rank == 0


def is_master(args, local=False):
    return is_local_master(args) if local else is_global_master(args)


def is_using_horovod():
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
        return True
    else:
        return False


def is_using_distributed():
    if 'WORLD_SIZE' in os.environ:
        return int(os.environ['WORLD_SIZE']) > 1
    if 'SLURM_NTASKS' in os.environ:
        return int(os.environ['SLURM_NTASKS']) > 1
    return False


def world_info_from_env():
    local_rank = 0
    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
        if v in os.environ:
            world_size = int(os.environ[v])
            break
    return local_rank, global_rank, world_size
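# Illustrative check (not repository code): world_info_from_env resolves rank
# information from whichever launcher populated the environment, taking the
# first variable it finds in each priority list, so torchrun-style variables
# win over the MPI and SLURM fallbacks.
def _example_world_info():
    os.environ.update({"LOCAL_RANK": "1", "RANK": "5", "WORLD_SIZE": "8"})
    assert world_info_from_env() == (1, 5, 8)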
def init_distributed_device(args):
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        os.environ['LOCAL_RANK'] = str(args.local_rank)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif is_using_distributed():
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ['LOCAL_RANK'] = str(args.local_rank)
            os.environ['RANK'] = str(args.rank)
            os.environ['WORLD_SIZE'] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url)
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # DDP via torchrun, torch.distributed.launch
        args.local_rank, _, _ = world_info_from_env()
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url)
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        args.distributed = True

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            device = 'cuda:%d' % args.local_rank
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    else:
        device = 'cpu'
    args.device = device
    device = torch.device(device)
    return device


def broadcast_object(args, obj, src=0):
    # broadcast a pickle-able python object from rank-0 to all ranks
    if args.horovod:
        return hvd.broadcast_object(obj, root_rank=src)
    else:
        if args.rank == src:
            objects = [obj]
        else:
            objects = [None]
        dist.broadcast_object_list(objects, src=src)
        return objects[0]


def all_gather_object(args, obj, dst=0):
    # gather a pickle-able python object across all ranks
    if args.horovod:
        return hvd.allgather_object(obj)
    else:
        objects = [None for _ in range(args.world_size)]
        dist.all_gather_object(objects, obj)
        return objects
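# Usage sketch (illustrative, not repository code; assumes `args` was set up
# by init_distributed_device above and args.horovod is False): both helpers
# accept any picklable python object.
def _example_collectives(args):
    resume_info = {"epoch": 7} if args.rank == 0 else None
    resume_info = broadcast_object(args, resume_info, src=0)   # same dict on every rank
    per_rank = all_gather_object(args, {"rank": args.rank})    # list with one entry per rank
    return resume_info, per_rank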
================================================
FILE: inf_clip/utils.py
================================================
from itertools import repeat
import collections.abc

import torch
from torch import nn as nn
from torchvision.ops.misc import FrozenBatchNorm2d


def freeze_batch_norm_2d(module, module_match={}, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = lambda n, x: _ntuple(n)(x)


# Replaces all linear layers with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight.data.copy_(old_module.weight.data)
                if model._modules[name].bias is not None:
                    model._modules[name].bias.data.copy_(old_module.bias)
    return model


def convert_int8_model_to_inference_mode(model):
    for m in model.modules():
        if hasattr(m, 'prepare_for_eval'):
            int8_original_dtype = m.weight.dtype
            m.prepare_for_eval()
            m.int8_original_dtype = int8_original_dtype


================================================
FILE: inf_clip/zero_shot_classifier.py
================================================
from functools import partial
from itertools import islice
from typing import Callable, List, Optional, Sequence, Union

import torch
import torch.nn.functional as F


def batched(iterable, n):
    """Batch data into lists of length *n*. The last batch may be shorter.
NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl """ it = iter(iterable) while True: batch = list(islice(it, n)) if not batch: break yield batch def build_zero_shot_classifier( model, tokenizer, classnames: Sequence[str], templates: Sequence[Union[Callable, str]], num_classes_per_batch: Optional[int] = 10, device: Union[str, torch.device] = 'cpu', use_tqdm: bool = False, ): """ Build zero-shot classifier weights by iterating over class names in batches Args: model: CLIP model instance tokenizer: CLIP tokenizer instance classnames: A sequence of class (label) names templates: A sequence of callables or format() friendly strings to produce templates per class name num_classes_per_batch: The number of classes to batch together in each forward, all if None device: Device to use. use_tqdm: Enable TQDM progress bar. """ assert isinstance(templates, Sequence) and len(templates) > 0 assert isinstance(classnames, Sequence) and len(classnames) > 0 use_format = isinstance(templates[0], str) num_templates = len(templates) num_classes = len(classnames) if use_tqdm: import tqdm num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) else: iter_wrap = iter def _process_batch(batch_classnames): num_batch_classes = len(batch_classnames) texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] texts = tokenizer(texts).to(device) class_embeddings = model.encode_text(texts, normalize=True) class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) class_embeddings = class_embeddings.T return class_embeddings with torch.no_grad(): if num_classes_per_batch: batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] zeroshot_weights = torch.cat(batched_embeds, dim=1) else: zeroshot_weights = _process_batch(classnames) return zeroshot_weights def build_zero_shot_classifier_legacy( model, tokenizer, classnames: Sequence[str], templates: Sequence[Union[Callable, str]], device: Union[str, torch.device] = 'cpu', use_tqdm: bool = False, ): """ Build zero-shot classifier weights by iterating over class names 1 by 1 Args: model: CLIP model instance tokenizer: CLIP tokenizer instance classnames: A sequence of class (label) names templates: A sequence of callables or format() friendly strings to produce templates per class name device: Device to use. use_tqdm: Enable TQDM progress bar. 
""" assert isinstance(templates, Sequence) and len(templates) > 0 assert isinstance(classnames, Sequence) and len(classnames) > 0 if use_tqdm: import tqdm iter_wrap = tqdm.tqdm else: iter_wrap = iter use_format = isinstance(templates[0], str) with torch.no_grad(): zeroshot_weights = [] for classname in iter_wrap(classnames): texts = [template.format(classname) if use_format else template(classname) for template in templates] texts = tokenizer(texts).to(device) # tokenize class_embeddings = model.encode_text(texts) class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) class_embedding /= class_embedding.norm() zeroshot_weights.append(class_embedding) zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) return zeroshot_weights ================================================ FILE: inf_clip/zero_shot_metadata.py ================================================ OPENAI_IMAGENET_TEMPLATES = ( lambda c: f'a bad photo of a {c}.', lambda c: f'a photo of many {c}.', lambda c: f'a sculpture of a {c}.', lambda c: f'a photo of the hard to see {c}.', lambda c: f'a low resolution photo of the {c}.', lambda c: f'a rendering of a {c}.', lambda c: f'graffiti of a {c}.', lambda c: f'a bad photo of the {c}.', lambda c: f'a cropped photo of the {c}.', lambda c: f'a tattoo of a {c}.', lambda c: f'the embroidered {c}.', lambda c: f'a photo of a hard to see {c}.', lambda c: f'a bright photo of a {c}.', lambda c: f'a photo of a clean {c}.', lambda c: f'a photo of a dirty {c}.', lambda c: f'a dark photo of the {c}.', lambda c: f'a drawing of a {c}.', lambda c: f'a photo of my {c}.', lambda c: f'the plastic {c}.', lambda c: f'a photo of the cool {c}.', lambda c: f'a close-up photo of a {c}.', lambda c: f'a black and white photo of the {c}.', lambda c: f'a painting of the {c}.', lambda c: f'a painting of a {c}.', lambda c: f'a pixelated photo of the {c}.', lambda c: f'a sculpture of the {c}.', lambda c: f'a bright photo of the {c}.', lambda c: f'a cropped photo of a {c}.', lambda c: f'a plastic {c}.', lambda c: f'a photo of the dirty {c}.', lambda c: f'a jpeg corrupted photo of a {c}.', lambda c: f'a blurry photo of the {c}.', lambda c: f'a photo of the {c}.', lambda c: f'a good photo of the {c}.', lambda c: f'a rendering of the {c}.', lambda c: f'a {c} in a video game.', lambda c: f'a photo of one {c}.', lambda c: f'a doodle of a {c}.', lambda c: f'a close-up photo of the {c}.', lambda c: f'a photo of a {c}.', lambda c: f'the origami {c}.', lambda c: f'the {c} in a video game.', lambda c: f'a sketch of a {c}.', lambda c: f'a doodle of the {c}.', lambda c: f'a origami {c}.', lambda c: f'a low resolution photo of a {c}.', lambda c: f'the toy {c}.', lambda c: f'a rendition of the {c}.', lambda c: f'a photo of the clean {c}.', lambda c: f'a photo of a large {c}.', lambda c: f'a rendition of a {c}.', lambda c: f'a photo of a nice {c}.', lambda c: f'a photo of a weird {c}.', lambda c: f'a blurry photo of a {c}.', lambda c: f'a cartoon {c}.', lambda c: f'art of a {c}.', lambda c: f'a sketch of the {c}.', lambda c: f'a embroidered {c}.', lambda c: f'a pixelated photo of a {c}.', lambda c: f'itap of the {c}.', lambda c: f'a jpeg corrupted photo of the {c}.', lambda c: f'a good photo of a {c}.', lambda c: f'a plushie {c}.', lambda c: f'a photo of the nice {c}.', lambda c: f'a photo of the small {c}.', lambda c: f'a photo of the weird {c}.', lambda c: f'the cartoon {c}.', lambda c: f'art of the {c}.', lambda c: f'a drawing of the {c}.', lambda c: f'a photo of the large {c}.', lambda c: f'a black and 
white photo of a {c}.', lambda c: f'the plushie {c}.', lambda c: f'a dark photo of a {c}.', lambda c: f'itap of a {c}.', lambda c: f'graffiti of the {c}.', lambda c: f'a toy {c}.', lambda c: f'itap of my {c}.', lambda c: f'a photo of a cool {c}.', lambda c: f'a photo of a small {c}.', lambda c: f'a tattoo of the {c}.', ) # a much smaller subset of above prompts # from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb SIMPLE_IMAGENET_TEMPLATES = ( lambda c: f'itap of a {c}.', lambda c: f'a bad photo of the {c}.', lambda c: f'a origami {c}.', lambda c: f'a photo of the large {c}.', lambda c: f'a {c} in a video game.', lambda c: f'art of the {c}.', lambda c: f'a photo of the small {c}.', ) IMAGENET_CLASSNAMES = ( "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", 
"Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", 
"rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", 
"mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", 
"water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" ) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" [project] name = "inf-cl" version = "1.2" authors = [ {name = "Zesen Cheng", email = "cyanlaser@stu.pku.edu.cn"}, {name = "Hang Zhang"}, {name = "Kehan Li"}, {name = "Xin Li"}, ] description = "A highly memory-efficient contrastive loss." readme = "README.md" requires-python = ">=3.8" license = {text = "MIT"} classifiers = [ 'Development Status :: 4 - Beta', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: MIT License', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', ] dependencies = [ 'numpy', 'triton>=2.2.0', ] [project.urls] Homepage = "https://github.com/clownrat6/Inf-CLIP/inf_cl" Issues = "https://github.com/clownrat6/Inf-CLIP/issues" [tool.pdm.build] excludes = ["./.git"] package-dir = "." 
includes = ["./inf_cl"] ================================================ FILE: requirements.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu118 # basic dependencies torch==2.2.0 torchvision==0.17.0 numpy==1.24.4 timm # data processing webdataset pandas ftfy regex braceexpand # The newest pillow fix this bug: "UserWarning: image file could not be identified because WEBP support not installed" pillow==10.4.0 # Refer to this issue: https://github.com/ContinuumIO/anaconda-issues/issues/10737 # logging tools tensorboard tensorboardX ================================================ FILE: scripts/benchmarks_eval.sh ================================================ clip_benchmark eval \ --model LiT-B-16 \ --pretrained work_dirs/epoch_8.pt \ --dataset datasets/imagenet.txt \ --recall_k 1 5 10 \ --dataset_root datasets/clip-benchmark/wds_{dataset_cleaned} \ --output "benchmark_{dataset}_{pretrained}_{model}_{language}_{task}.json" ================================================ FILE: scripts/cc12m/clip_vit-b-32_bs32k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! -n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=32768 LOCAL_BATCH_SIZE=512 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=40 TRAIN_NUM_SAMPLES=10445970 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=clip_cc12m RUN_NAME=vit-b-32_bs32k_e40 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model ViT-B-32 \ --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/IMAGE/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --lr 5e-4 \ --beta1 0.9 \ --beta2 0.98 \ --eps 1.0e-8 \ --wd 0.5 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 5 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ ================================================ FILE: scripts/cc12m/lit_vit-b-16_bs32k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=32768 LOCAL_BATCH_SIZE=1024 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=20 TRAIN_NUM_SAMPLES=10445970 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_cc12m RUN_NAME=lit-b-16_bs32k_e20 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-B-16 \ --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/cc12m/lit_vit-b-32_bs32k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=32768 LOCAL_BATCH_SIZE=1024 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=20 TRAIN_NUM_SAMPLES=10445970 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_cc12m RUN_NAME=lit-b-32_bs32k_e20 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-B-32 \ --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/cc3m/clip_r50_bs4k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=4096 LOCAL_BATCH_SIZE=256 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=40 TRAIN_NUM_SAMPLES=3018714 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=clip_cc3m RUN_NAME=r50_bs4k_e40 DATA_DIR=/mnt/damovl/MEDIA OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model RN50 \ --train-data ${DATA_DIR}'/cc3m/{0000..0301}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --lr 5e-4 \ --beta1 0.9 \ --beta2 0.98 \ --eps 1.0e-8 \ --wd 0.5 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 5 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/cc3m/clip_vit-b-32_bs16k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=16384 LOCAL_BATCH_SIZE=256 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=40 TRAIN_NUM_SAMPLES=3018714 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=clip_cc3m RUN_NAME=vit-b-32_bs16k_e40 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model ViT-B-32 \ --train-data ${DATA_DIR}'/cc3m/cc3m-train-{0000..0575}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --lr 5e-4 \ --beta1 0.9 \ --beta2 0.98 \ --eps 1.0e-8 \ --wd 0.5 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 5 \ --log_dir $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/cc3m/lit_vit-b-32_bs16k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=16384 LOCAL_BATCH_SIZE=256 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=20 TRAIN_NUM_SAMPLES=3018714 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_cc3m RUN_NAME=lit-b-32_bs16k_e20 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-B-32 \ --train-data ${DATA_DIR}'/cc3m/{0000..1044}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 32 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/imagenet_eval.sh ================================================ torchrun --nproc_per_node 1 \ -m inf_cl_train.main \ --imagenet-val datasets/imagenet-1k/val \ --model ViT-B-16 \ --pretrained openai \ --workers 64 \ ================================================ FILE: scripts/laion400m/clip_vit-b-32_bs256k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=262144 LOCAL_BATCH_SIZE=512 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=8 TRAIN_NUM_SAMPLES=280321756 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=clip_laion400m RUN_NAME=vit-b-32_bs256k_e8 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model ViT-B-32 \ --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --lr 5e-4 \ --beta1 0.9 \ --beta2 0.98 \ --eps 1.0e-8 \ --wd 0.5 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ ================================================ FILE: scripts/laion400m/lit_vit-b-16_bs256k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=262144 LOCAL_BATCH_SIZE=512 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=8 TRAIN_NUM_SAMPLES=280321756 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_laion400m RUN_NAME=lit-b-16_bs256k_e8 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-B-16 \ --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/laion400m/lit_vit-b-32_bs256k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=262144 LOCAL_BATCH_SIZE=512 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=8 TRAIN_NUM_SAMPLES=280321756 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_laion400m RUN_NAME=lit-b-32_bs256k_e8 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-B-32 \ --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: scripts/laion400m/lit_vit-l-16_bs256k.sh ================================================ # Environment Variables ARG_WORLD_SIZE=${1:-1} ARG_NPROC_PER_NODE=${2:-8} ARG_MASTER_ADDR="127.0.0.1" ARG_MASTER_PORT=16666 ARG_RANK=${3:-0} # Multiple conditions if [ ! -n "$WORLD_SIZE" ] || [ ! -n "$NPROC_PER_NODE" ]; then WORLD_SIZE=$ARG_WORLD_SIZE NPROC_PER_NODE=$ARG_NPROC_PER_NODE fi if [ ! -n "$MASTER_ADDR" ] || [ ! -n "$MASTER_PORT" ] || [ ! 
-n "$RANK" ]; then MASTER_ADDR=$ARG_MASTER_ADDR MASTER_PORT=$ARG_MASTER_PORT RANK=$ARG_RANK fi echo "WORLD_SIZE: $WORLD_SIZE" echo "NPROC_PER_NODE: $NPROC_PER_NODE" # Training Arguments GLOBAL_BATCH_SIZE=262144 LOCAL_BATCH_SIZE=512 ACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)] EPOCHS=8 TRAIN_NUM_SAMPLES=280321756 WARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)] echo "ACCUMULATION_STEPS: $ACCUMULATION_STEPS" # Log Arguments export TRANSFORMERS_OFFLINE=1 export WANDB_PROJECT=lit_laion400m RUN_NAME=lit-l-16_bs256k_e8 DATA_DIR=datasets OUTP_DIR=work_dirs torchrun --nnodes $WORLD_SIZE \ --nproc_per_node $NPROC_PER_NODE \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ --node_rank $RANK \ -m inf_clip.train.main \ --model LiT-L-16 \ --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \ --train-num-samples $TRAIN_NUM_SAMPLES \ --aug-cfg scale='(0.08, 1.0)'\ --dataset-type webdataset \ --imagenet-val ${DATA_DIR}/imagenet-1k/val \ --epochs $EPOCHS \ --warmup $WARMUP_STEPS \ --batch-size $LOCAL_BATCH_SIZE \ --accum-freq $ACCUMULATION_STEPS \ --optim adafactor \ --lr 1e-3 \ --beta1 0.9 \ --beta2 0.95 \ --eps 1.0e-8 \ --wd 1e-4 \ --grad-clip-norm 1.0 \ --workers 16 \ --precision amp \ --infloss \ --log-every-n-steps 1 \ --logs $OUTP_DIR/$WANDB_PROJECT \ --name $RUN_NAME \ --save-frequency 1 \ --zeroshot-frequency 1 \ --report-to tensorboard \ --resume latest \ ================================================ FILE: tests/example.py ================================================ import torch import torch.nn.functional as F import torch.distributed as dist import numpy as np from inf_cl import cal_inf_loss def create_cl_tensors(rank, world_size): # Parameters dtype = torch.float32 num_heads = 3 # Number of attention heads seq_length_q = 32768 # Sequence length seq_length_k = 32768 d_model = 256 # Dimension of each head (must be 16, 32, 64, or 128) # Randomly initialize inputs q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}") k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}") l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07) q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key l = l.requires_grad_() # Logit scale return q, k, l if __name__ == "__main__": # Assume that the distributed environment has been initialized dist.init_process_group("nccl") rank = dist.get_rank() world_size = dist.get_world_size() torch.cuda.set_device(rank) # Exampled by Image-Text Contrastive Learning, q is the global image features, # k is the text features, and l is the logit scale. q, k, l = create_cl_tensors(rank, world_size) # labels are diagonal elements by default. # labels = torch.arange(q.shape[0]) loss = cal_inf_loss(q, k, scale=l.exp()) print(loss)