[
  {
    "path": ".gitattributes",
    "content": "*.py linguist-language=python\n*.ipynb linguist-documentation\n"
  },
  {
    "path": ".gitignore",
    "content": "**/logs/\n**/wandb/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\nsync.sh\ngpu1sync.sh\n.idea\n*.pdf\n**/._*\n**/*DS_*\n**.jsonl\nsrc/sbatch\nsrc/misc\n.vscode\nsrc/debug\ncore.*\n\n# Allow\n!src/evaluation/misc/results_dbs/*\n\n# log dirs\n/work_dirs*/\n/datasets/\n\n# oss logs\n/.ossutil*\n/ossutil*"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n    <img src=\"https://github.com/user-attachments/assets/53a09bd1-c8ac-43c0-80ae-03ba284c94ad\" width=\"150\" style=\"margin-bottom: 0.2;\"/>\n<p>\n\n<h3 align=\"center\"><a href=\"https://arxiv.org/abs/2410.17243\">\nBreaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss</a></h3>\n<h5 align=\"center\"> If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏 </h2>\n\n<h5 align=\"center\">\n\n[![arXiv](https://img.shields.io/badge/Arxiv-2410.17243-AD1C18.svg?logo=arXiv)](https://arxiv.org/abs/2410.17243)\n[![hf_paper](https://img.shields.io/badge/🤗-Paper%20In%20HF-red.svg)](https://huggingface.co/papers/2410.17243)\n[![PyPI](https://img.shields.io/badge/PyPI-Inf--CL-9C276A.svg)](https://pypi.org/project/inf-cl) <br>\n[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/LICENSE)\n[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FDAMO-NLP-SG%2FInf-CLIP&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false)](https://hits.seeyoufarm.com)\n[![GitHub issues](https://img.shields.io/github/issues/DAMO-NLP-SG/Inf-CLIP?color=critical&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aopen+is%3Aissue)\n[![GitHub closed issues](https://img.shields.io/github/issues-closed/DAMO-NLP-SG/Inf-CLIP?color=success&label=Issues)](https://github.com/DAMO-NLP-SG/Inf-CLIP/issues?q=is%3Aissue+is%3Aclosed)  <br>\n[![zhihu](https://img.shields.io/badge/-知乎-000000?logo=zhihu&logoColor=0084FF)](https://zhuanlan.zhihu.com/p/1681887214)\n[![Twitter](https://img.shields.io/badge/-Twitter-black?logo=twitter&logoColor=1D9BF0)](https://x.com/lixin4ever/status/1849669129613226457) <br>\n\n</h5>\n\n<div align=\"center\"><img src=\"https://github.com/user-attachments/assets/2c19838b-43d8-4145-b28c-903f3d76f8ab\" width=\"800\" /></div>\n\n<details open><summary>💡 Some other multimodal foundation model projects from our team may interest you ✨. 
</summary><p>\n\n> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922) <br>\n> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing <br>\n[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VCD)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VCD.svg?style=social)](https://github.com/DAMO-NLP-SG/VCD)  [![arXiv](https://img.shields.io/badge/Arxiv-2311.16922-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.16922) <br>\n\n> [**VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs**](https://github.com/DAMO-NLP-SG/VideoLLaMA2) <br>\n> Zesen Cheng, Sicong Leng, Hang Zhang, Yifei Xin, Xin Li, Guanzheng Chen, Yongxin Zhu, Wenqi Zhang, Ziyang Luo, Deli Zhao, Lidong Bing <br>\n[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VideoLLaMA2)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA2.svg?style=social)](https://github.com/DAMO-NLP-SG/VideoLLaMA2) [![arXiv](https://img.shields.io/badge/Arxiv-2406.07476-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2406.07476) <br>\n\n> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787) <br>\n> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing <br>\n[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/CMM)  [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/CMM.svg?style=social)](https://github.com/DAMO-NLP-SG/CMM)  [![arXiv](https://img.shields.io/badge/Arxiv-2410.12787-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.12787) <br>\n\n</p></details>\n\n## 📰 News\n* **[2024.10.18]**  Release training and evaluation code of Inf-CLIP.\n\n<div align=\"center\"><img src=\"https://github.com/user-attachments/assets/11c5cc32-aac2-497d-bbc1-33e065a71be0\" width=\"800\" /></div>\n\n## 🛠️ Requirements and Installation\n\nBasic Dependencies:\n* Python >= 3.8\n* PyTorch >= 2.0.0\n* CUDA Version >= 11.8\n\n[Remote] Install Inf-CL:\n```bash\n# install from PyPI\npip install inf_cl -i https://pypi.org/simple\n```\n\n[Local] Install Inf-CL (inside a clone of this repo):\n```bash\npip install -e .\n```\n\nInstall required packages:\n```bash\ngit clone https://github.com/DAMO-NLP-SG/Inf-CLIP\ncd Inf-CLIP\npip install -r requirements.txt\n```\n\n## ⭐ Features\n\n`inf_cl` is the Triton implementation of the Inf-CL loss:\n* [x] [Ring-CL (inf_cl/ring.py#L238)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/models/ops/ring.py#L238)\n* [x] [Inf-CL (inf_cl/ring.py#L251)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip/models/ops/ring.py#L251)\n\n`inf_clip` is the CLIP training codebase with Inf-CL loss and other training features:\n- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip_train/train.py#L180)\n- [x] [Gradient Cache (inf_clip/train/train.py#L292)](https://github.com/DAMO-NLP-SG/Inf-CLIP/blob/main/inf_clip_train/train.py#L292)\n\n\n## 🔑 Usage\n\nA simple example of how to adopt our Inf-CL loss for contrastive learning is shown below. 
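For a quick single-GPU sanity check (no `torchrun` needed), note that `cal_inf_loss` falls back to the non-distributed flash implementation when `torch.distributed` is not initialized. A minimal sketch with illustrative shapes:\n\n```python\nimport torch\nimport torch.nn.functional as F\n\nfrom inf_cl import cal_inf_loss\n\n# 1024 image-text pairs with 768-dim features (768 = 3 heads x head_dim 256).\nq = F.normalize(torch.rand(1024, 768, device=\"cuda\"), dim=-1).requires_grad_()  # image features\nk = F.normalize(torch.rand(1024, 768, device=\"cuda\"), dim=-1).requires_grad_()  # text features\n\nloss = cal_inf_loss(q, k)  # labels default to the diagonal (matched pairs)\nloss.backward()\n```\n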
To run the full distributed example (2 GPUs in this case), use:\n```\ntorchrun --nproc_per_node 2 tests/example.py\n```\n\n```python\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport numpy as np\n\nfrom inf_cl import cal_inf_loss\n\n\ndef create_cl_tensors(rank, world_size):\n    # Parameters\n    dtype = torch.float32\n    num_heads = 3        # Number of attention heads\n    seq_length_q = 32768 # Sequence length\n    seq_length_k = 32768\n    d_model = 256        # Dimension of each head (must be 16, 32, 64, 128, or 256)\n\n    # Randomly initialize inputs\n    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda:{rank}\")\n    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda:{rank}\")\n    l = torch.ones([], dtype=dtype, device=f\"cuda:{rank}\") * np.log(1 / 0.07)\n\n    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query\n    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key\n    l = l.requires_grad_() # Logit scale\n\n    return q, k, l\n\n\nif __name__ == \"__main__\":\n    # Initialize the distributed environment\n    dist.init_process_group(\"nccl\")\n\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n\n    torch.cuda.set_device(rank)\n\n    # Taking image-text contrastive learning as an example: q is the per-GPU image features,\n    # k is the per-GPU text features, and l is the logit scale.\n    q, k, l = create_cl_tensors(rank, world_size)\n\n    # labels are the diagonal elements by default, i.e., matched image-text pairs.\n    # labels = torch.arange(q.shape[0])\n    loss = cal_inf_loss(q, k, scale=l.exp())\n\n    print(loss)\n\n```\n\n## 🚀 Main Results\n\n### Memory Cost\n<p><img src=\"https://github.com/user-attachments/assets/05dd3fea-0a93-4716-b321-0a94965e1fbe\" width=\"800\" /></p>\n\n\\* denotes adopting the \"data offload\" strategy. \n\n### Max Supported Batch Size\n<p><img src=\"https://github.com/user-attachments/assets/eb38fb90-3b7e-4696-b078-b7766893f758\" width=\"800\" /></p>\n\n### Speed\n<p><img src=\"https://github.com/user-attachments/assets/da72e99b-508b-450a-b12e-401d4991291a\" width=\"800\" /></p>\n\n### Batch Size Scaling\n<p><img src=\"https://github.com/user-attachments/assets/5b55fa98-6558-4509-9b66-e290ecf77b41\" width=\"800\" /></p>\n\nTraining at a larger data scale requires a larger batch size.\n\n## 🗝️ Training & Evaluation\n\n### Quick Start\n\nTo facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP and evaluate the trained model on the mainstream CLIP benchmarks.\n\n1. Training Data Structure:\n```bash\nInf-CLIP\n├── datasets\n│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md\n|   |   ├── 0000.tar\n|   |   ├── 0001.tar\n|   |   ├── ...\n|   |   └── 0301.tar\n│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md\n|   |   ├── 0000.tar\n|   |   ├── 0001.tar\n|   |   ├── ...\n|   |   └── 1044.tar\n│   ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md\n|   |   ├── 00000.tar\n|   |   ├── 00001.tar\n|   |   ├── ...\n|   |   └── 41407.tar\n```\n2. Command:\n```bash\nbash scripts/cc3m/lit_vit-b-32_bs16k.sh\nbash scripts/cc12m/lit_vit-b-32_bs32k.sh\nbash scripts/laion400m/lit_vit-b-32_bs256k.sh\n```\n3. 
Evaluation Data Structure:\n```bash\nInf-CLIP\n├── datasets\n│   ├── imagenet-1k/ # download val_images.tar.gz of imagenet from https://huggingface.co/datasets/ILSVRC/imagenet-1k/tree/main/data\n|   |   └── val/ # python datasets/reformat_imagenet.py\n|   |   |   ├── n01440764\n|   |   |   ├── n01443537\n|   |   |   ├── ...\n|   |   |   └── n15075141\n│   ├── clip-benchmark/ # bash datasets/benchmarks_download.sh\n|   |   ├── wds_mscoco_captions\n|   |   ├── wds_flickr8k\n|   |   ├── wds_flickr30k\n|   |   ├── wds_imagenet1k\n|   |   ├── wds_imagenetv2\n|   |   ├── wds_imagenet_sketch\n|   |   ├── wds_imagenet-a\n|   |   ├── wds_imagenet-r\n|   |   ├── wds_imagenet-o\n|   |   └── wds_objectnet\n```\n4. Command:\n```bash\n# imagenet evaluation\nbash scripts/imagenet_eval.sh\n# overall evaluation\nbash scripts/benchmarks_eval.sh\n```\n\n## 📑 Citation\n\nIf you find Inf-CLIP useful for your research and applications, please cite using this BibTeX:\n```bibtex\n@article{damovl2024infcl,\n  title={Breaking the Memory Barrier: Near Infinite Batch Size Scaling for Contrastive Loss},\n  author={Zesen Cheng and Hang Zhang and Kehan Li and Sicong Leng and Zhiqiang Hu and Fei Wu and Deli Zhao and Xin Li and Lidong Bing},\n  journal={arXiv preprint arXiv:2410.17243},\n  year={2024},\n  url={https://arxiv.org/abs/2410.17243}\n}\n```\n\n## 👍 Acknowledgement\nThe codebase of Inf-CLIP is adapted from [**OpenCLIP**](https://github.com/mlfoundations/open_clip). We are also grateful to the following projects that Inf-CL builds on:\n* [**OpenAI CLIP**](https://openai.com/index/clip/), [**img2dataset**](https://github.com/rom1504/img2dataset), [**CLIP-Benchmark**](https://github.com/LAION-AI/CLIP_benchmark).\n* [**FlashAttention**](https://github.com/Dao-AILab/flash-attention), [**RingAttention**](https://github.com/haoliuhl/ringattention), [**RingFlashAttention**](https://github.com/zhuzilin/ring-flash-attention). \n\n\n## 🔒 License\n\nThis project is released under the Apache 2.0 license as found in the LICENSE file.\nThis service is a research preview intended for **non-commercial use ONLY**, subject to the model license of CLIP, the Terms of Use of the data generated by OpenAI, and the terms of LAION. Please get in touch with us if you find any potential violations.\n"
  },
  {
    "path": "inf_cl/__init__.py",
    "content": "from .flash import cal_flash_loss\nfrom .ring  import cal_ring_loss, cal_inf_loss"
  },
  {
    "path": "inf_cl/flash.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _prob_fwd_kernel(\n    Q,\n    K,\n    LSE,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    BLOCK_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    # start index of sequence length\n    start_m = tl.program_id(0)\n\n    # initialize offsets\n    ndims = nheads * BLOCK_HEADDIM\n    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n\n    # Initialize pointers to Q, K, V\n    q_ptrs = Q + ndims * offs_m[:, None]\n    k_ptrs = K + ndims * offs_n[:, None]\n    # initialize pointer to m and l\n    lse_i    = tl.zeros([BLOCK_M], dtype=tl.float32)\n    m_i      = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n\n    # loop over k, v and update accumulator\n    end_n = seqlen_k\n    for start_n in range(0, end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        for off_h in range(nheads):\n            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]\n            # -- fetch q and k of a single head ----\n            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)\n            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n            # -- compute qk ----\n            qk += tl.dot(q, tl.trans(k))\n\n        # Trying to combine the two masks seem to make the result wrong\n        m_ij = tl.maximum(tl.max(qk, 1), m_i)\n        p = tl.exp(qk - m_ij[:, None])\n        # Fix out of bound access\n        p = tl.where((start_n + offs_n)[None, :] < seqlen_k, p, 0.0)\n        # -- update statistics\n        lse_i = tl.exp(m_i - m_ij) * lse_i + tl.sum(p, 1)\n        m_i = m_ij\n\n    lse_i = m_i + tl.log(lse_i)\n    # mask out the padded values\n    lse_i = tl.where(offs_m < seqlen_q, lse_i, 0.0)\n\n    tl.store(LSE + offs_m, lse_i)\n\n\n@triton.jit\ndef _dq_prob_bwd_kernel(\n    Q,\n    K,\n    dQ,\n    LSE,\n    dLSE,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    BLOCK_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    ASM: tl.constexpr = \"cvt.rna.tf32.f32 $0, $1;\"\n    # start index of sequence length\n    start_m = tl.program_id(0)\n\n    # initialize offsets\n    ndims = nheads * BLOCK_HEADDIM\n    offs_m = tl.arange(0, BLOCK_M) + start_m * BLOCK_M\n    offs_n = tl.arange(0, BLOCK_N)\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n\n    # Initialize pointers to Q, K, V\n    q_ptrs  = Q  + ndims * offs_m[:, None]\n    dq_ptrs = dQ + ndims * offs_m[:, None]\n    k_ptrs  = K  + ndims * offs_n[:, None]\n    # setting lse\n    lse = tl.load(LSE + offs_m, mask=offs_m < seqlen_q, other=0.0)\n    dlse = tl.load(dLSE + offs_m, mask=offs_m < seqlen_q, other=0.0)\n\n    # loop over k, v and update accumulator\n    end_n = seqlen_k        \n    for start_n in range(0, end_n, BLOCK_N):\n        start_n = tl.multiple_of(start_n, BLOCK_N)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        for off_h in range(nheads):\n            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]\n            # -- fetch q and k of a single head ----\n            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)\n            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, 
None] < seqlen_k, other=0.0)\n            # -- compute qk ----\n            qk += tl.dot(q, tl.trans(k))\n\n        qk_grad = tl.exp(qk - lse[:, None])\n        qk_grad = tl.where((start_n + offs_n)[None, :] < seqlen_k, qk_grad, 0.0)\n        qk_grad = qk_grad * dlse[:, None]\n        qk_grad = tl.inline_asm_elementwise(ASM, \"=r, r\", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)\n        for off_h in range(nheads):\n            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]\n            # -- fetch q and k of a single head ----\n            q = tl.load(q_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)\n            k = tl.load(k_ptrs + offs_hd + start_n * ndims, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)\n            # -- compute q grad ----\n            # NOTE: tl.float32 adopt tf32, which causes precision inconsistency with torch\n            # A solution for this problem\n            # Refer to issue: https://github.com/triton-lang/triton/issues/4574\n            # if allow_tf32:\n            k = tl.inline_asm_elementwise(ASM, \"=r, r\", [k], dtype=tl.float32, is_pure=True, pack=1)\n            q_grad = tl.dot(qk_grad, k)\n            # Another solution for this problem\n            # Refer to https://github.com/triton-lang/triton/issues/376\n            # q_grad = tl.dot(qk_grad, k.to(tl.float32), allow_tf32=False)\n            # -- store dq ----\n            dq_h = tl.load(dq_ptrs + offs_hd, mask=offs_m[:, None] < seqlen_q, other=0.0)\n            tl.store(dq_ptrs + offs_hd, dq_h + q_grad, mask=offs_m[:, None] < seqlen_q)\n\n\n@triton.jit\ndef _dk_prob_bwd_kernel(\n    Q,\n    K,\n    dK,\n    LSE,\n    dLSE,\n    nheads,\n    seqlen_q,\n    seqlen_k,\n    BLOCK_HEADDIM: tl.constexpr,\n    BLOCK_M: tl.constexpr,\n    BLOCK_N: tl.constexpr,\n):\n    ASM: tl.constexpr = \"cvt.rna.tf32.f32 $0, $1;\"\n    # start index of sequence length\n    start_n = tl.program_id(0)\n\n    # initialize offsets\n    ndims = nheads * BLOCK_HEADDIM\n    offs_m = tl.arange(0, BLOCK_M)\n    offs_n = tl.arange(0, BLOCK_N) + start_n * BLOCK_N\n    offs_d = tl.arange(0, BLOCK_HEADDIM)\n\n    # Initialize pointers to Q, K, V\n    q_ptrs  = Q  + ndims * offs_m[:, None]\n    k_ptrs  = K  + ndims * offs_n[:, None]\n    dk_ptrs = dK + ndims * offs_n[:, None]\n\n    # loop over q and update accumulator\n    end_m = seqlen_q        \n    for start_m in range(0, end_m, BLOCK_M):\n        start_m = tl.multiple_of(start_m, BLOCK_M)\n\n        # setting lse\n        lse = tl.load(LSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)\n        dlse = tl.load(dLSE + offs_m + start_m, mask=offs_m < seqlen_q, other=0.0)\n\n        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n        for off_h in range(nheads):\n            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, :]\n            # -- fetch q and k of a single head ----\n            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(offs_m + start_m)[:, None] < seqlen_q, other=0.0)\n            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)\n            # -- compute qk ----\n            qk += tl.dot(q, tl.trans(k))\n\n        qk_grad = tl.exp(qk - lse[:, None])\n        qk_grad = tl.where((start_m + offs_m)[:, None] < seqlen_q, qk_grad, 0.0)\n        qk_grad = qk_grad * dlse[:, None]\n        qk_grad = tl.inline_asm_elementwise(ASM, \"=r, r\", [qk_grad], dtype=tl.float32, is_pure=True, pack=1)\n        for off_h in range(nheads):\n            offs_hd = (offs_d + off_h * BLOCK_HEADDIM)[None, 
:]\n            # -- fetch q and k of a single head ----\n            q = tl.load(q_ptrs + offs_hd + start_m * ndims, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0)\n            k = tl.load(k_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)\n            # -- compute k grad ----\n            q = tl.inline_asm_elementwise(ASM, \"=r, r\", [q], dtype=tl.float32, is_pure=True, pack=1)\n            k_grad = tl.dot(tl.trans(qk_grad), q)\n            # k_grad = tl.dot(tl.trans(qk_grad), q.to(tl.float32))\n            # -- store dk ----\n            dk_h = tl.load(dk_ptrs + offs_hd, mask=(offs_n)[:, None] < seqlen_k, other=0.0)\n            tl.store(dk_ptrs + offs_hd, dk_h + k_grad, mask=(offs_n)[:, None] < seqlen_k)\n\n\ndef _flash_prob_forward(q, k):\n    # shape constraints\n    seqlen_q, nheads, d = q.shape\n    seqlen_k, _, _ = k.shape\n    assert k.shape == (seqlen_k, nheads, d)\n    # assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n    assert q.dtype == k.dtype, \"All tensors must have the same type\"\n    # assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n    assert q.is_cuda and k.is_cuda\n\n    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128\n    lse = torch.empty((seqlen_q_rounded), device=q.device, dtype=torch.float32)\n\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    BLOCK_M = 64\n    BLOCK_N = 64\n    num_warps = 8\n    num_stages = 1\n    grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), 1)\n    _prob_fwd_kernel[grid](\n        q,\n        k,\n        lse,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        BLOCK_HEADDIM,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    lse = lse[:seqlen_q]\n\n    return lse\n\n\ndef _flash_prob_backward(q, k, lse, dlse):\n    # shape constraints\n    seqlen_q, nheads, d = q.shape\n    seqlen_k, _, _ = k.shape\n    assert k.shape == (seqlen_k, nheads, d)\n    # assert d <= 128, \"FlashAttention only support head dimensions up to 128\"\n    assert q.dtype == k.dtype, \"All tensors must have the same type\"\n    # assert q.dtype in [torch.float16, torch.bfloat16], \"Only support fp16 and bf16\"\n    assert q.is_cuda and k.is_cuda\n\n    dq = torch.zeros_like(q, dtype=torch.float32)\n    dk = torch.zeros_like(k, dtype=torch.float32)\n\n    q = q.contiguous()\n    k = k.contiguous()\n    dlse = dlse.contiguous()\n\n    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)\n    BLOCK_M = 64\n    BLOCK_N = 64\n    num_warps = 8\n    num_stages = 1\n    grid = lambda META: (triton.cdiv(seqlen_q, META[\"BLOCK_M\"]), 1)\n    _dq_prob_bwd_kernel[grid](\n        q,\n        k,\n        dq,\n        lse,\n        dlse,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        BLOCK_HEADDIM,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    BLOCK_N = BLOCK_M\n    BLOCK_M = BLOCK_N\n    grid = lambda META: (triton.cdiv(seqlen_k, META[\"BLOCK_N\"]), 1)\n    _dk_prob_bwd_kernel[grid](\n        q,\n        k,\n        dk,\n        lse,\n        dlse,\n        nheads,\n        seqlen_q,\n        seqlen_k,\n        BLOCK_HEADDIM,\n        BLOCK_M=BLOCK_M,\n        BLOCK_N=BLOCK_N,\n        num_warps=num_warps,\n        num_stages=num_stages,\n    )\n\n    dq = dq[:seqlen_q]\n    dk = dk[:seqlen_k]\n\n    return dq, dk\n\n\nclass FlashProb(torch.autograd.Function):\n\n    @staticmethod\n    
def forward(ctx, q, k):\n        lse = _flash_prob_forward(q, k)\n        ctx.save_for_backward(q, k, lse)\n\n        return lse\n\n    @staticmethod\n    def backward(ctx, dlse):\n        q, k, lse = ctx.saved_tensors\n        dq, dk = _flash_prob_backward(q, k, lse, dlse)\n\n        return dq, dk\n\n\ndef _cal_flash_loss(q, k, labels, head_dim=256):\n    bq = q.shape[0]\n    bk = k.shape[0]\n    # NOTE: logits forward or backward should keep fp32 for better precision\n    q = q.view(bq, -1, head_dim).float()\n    k = k.view(bk, -1, head_dim).float()\n\n    lse = FlashProb.apply(q, k)\n    numerator = torch.einsum(\"mhd,mhd->m\", q, k[labels, ...])\n    loss = -numerator + lse\n\n    return loss\n\n\ndef cal_flash_loss(q, k, labels=None, scale=None, head_dim=256):\n    if labels is None:\n        labels = torch.arange(q.shape[0], device=q.device)\n    if scale is None:\n        scale = 1.0\n    return _cal_flash_loss(scale * q, k, labels, head_dim)\n\n\nif __name__ == '__main__':\n    import time\n\n    # Parameters\n    num_heads = 3        # Number of attention heads\n    seq_length_q = 32768 # Sequence length\n    seq_length_k = 32768\n    d_model = 256        # Dimension of each head (must be 16, 32, 64, or 128)\n\n    # Randomly initialize inputs\n    q = torch.rand((seq_length_q, num_heads * d_model), dtype=torch.float32, device=\"cuda\") # Query\n    k = torch.rand((seq_length_k, num_heads * d_model), dtype=torch.float32, device=\"cuda\") # Key\n    l = torch.ones([], device=\"cuda\") * np.log(1 / 0.02); l.requires_grad = True\n\n    q = F.normalize(q, p=2, dim=-1); q.requires_grad = True\n    k = F.normalize(k, p=2, dim=-1); k.requires_grad = True\n\n    q1 = q.clone().detach().requires_grad_(True)\n    k1 = k.clone().detach().requires_grad_(True)\n    l1 = l.clone().detach().requires_grad_(True)\n\n    labels = torch.arange(seq_length_q).cuda()\n\n    for i in range(1000):\n\n        # A. torch gradient\n        start = time.time()\n        qk = torch.einsum(\"md,nd->mn\", l.exp() * q, k)\n        loss = F.cross_entropy(qk, labels, reduction=\"mean\")\n        loss.backward()\n        end = time.time()\n\n        # B. triton gradient\n        start1 = time.time()\n        loss1 = cal_flash_loss(q1, k1, labels, l1.exp())\n        loss1 = loss1.mean()\n        loss1.backward()\n        end1 = time.time()\n\n        print(\"========= Difference =========\")\n        print(end - start, end1 - start1, l.grad, l1.grad)\n        print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))\n\n        q.grad = None; k.grad = None; l.grad = None\n        q1.grad = None; k1.grad = None; l1.grad = None\n"
  },
  {
    "path": "inf_cl/ring.py",
    "content": "import os\nimport math\nimport random\n\nimport torch\nimport torch.distributed as dist\nimport torch.distributed.nn as dist_nn\nimport torch.nn.functional as F\nimport numpy as np\n\nimport triton\nimport triton.language as tl\n\nfrom .flash import _flash_prob_forward, _flash_prob_backward, _cal_flash_loss\n\n\nclass RingComm:\n\n    def __init__(self, process_group: dist.ProcessGroup):\n        self._process_group = process_group\n        self._ops = []\n        self.rank = dist.get_rank(self._process_group)\n        self.world_size = dist.get_world_size(self._process_group)\n        self._reqs = None\n\n        self.send_rank = (self.rank + 1) % self.world_size\n        self.recv_rank = (self.rank - 1) % self.world_size\n        # print(f'rank: {self.rank}, send_rank: {self.send_rank}, recv_rank: {self.recv_rank}')\n        if process_group is not None:\n            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)\n            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)\n\n    def send_recv(self, to_send, recv_tensor = None):\n        if recv_tensor is None:\n            res = torch.empty_like(to_send)\n        else:\n            res = recv_tensor\n\n        send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)\n        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)\n        self._ops.append(send_op)\n        self._ops.append(recv_op)\n        return res\n\n    def commit(self):\n        if self._reqs is not None:\n            raise RuntimeError(\"commit called twice\")\n        self._reqs = dist.batch_isend_irecv(self._ops)\n\n    def wait(self):\n        if self._reqs is None:\n            raise RuntimeError(\"wait called before commit\")\n        for req in self._reqs:\n            req.wait()\n        self._reqs = None\n        self._ops = []\n\n\nclass GradientGather(torch.autograd.Function):\n    \n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return x\n\n    @staticmethod\n    def backward(ctx, dx):\n        dist.all_reduce(dx)\n        return dx\n\n\nclass RingProb(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q, k, group):\n        rank = dist.get_rank()\n        k = k.contiguous()\n        comm = RingComm(group)\n\n        colle = [q, k]\n\n        lse = None\n        next_k = None\n        for step in range(comm.world_size):\n            if step + 1 != comm.world_size:\n                next_k: torch.Tensor = comm.send_recv(k)\n                comm.commit()\n\n            # vanilla lse\n            qk = torch.einsum(\"mhd,nhd->mn\", q, k)\n            block_lse = torch.log(torch.exp(qk).sum(dim=-1))\n\n            if step == 0:\n                lse = block_lse\n            else:\n                lse = lse - F.logsigmoid(lse - block_lse)\n\n            if step + 1 != comm.world_size:\n                comm.wait()\n                k = next_k\n\n        # this should be out_padded\n        colle.append(lse)\n        ctx.save_for_backward(*colle)\n        ctx.group = group\n        return lse\n\n    @staticmethod\n    def backward(ctx, dlse):\n        rank = dist.get_rank()\n        q, k, lse = ctx.saved_tensors\n        k_comm = RingComm(ctx.group)\n        d_k_comm = RingComm(ctx.group)\n        dq, dk = None, None\n        next_dk = None\n\n        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)\n        block_dk_buffer = torch.empty(k.shape, 
dtype=torch.float32, device=k.device)\n\n        next_dk, next_k = None, None\n\n        for step in range(k_comm.world_size):\n            if step + 1 != k_comm.world_size:\n                next_k = k_comm.send_recv(k)\n                k_comm.commit()\n\n            # vanilla gradient calculation\n            qk = torch.einsum(\"mhd,nhd->mn\", q, k)\n            qk_grad = torch.exp(qk - lse[:, None]).float()\n            qk_grad = qk_grad * dlse[:, None]\n            block_dq_buffer = torch.einsum(\"mn,nhd->mhd\", qk_grad, k.float())\n            block_dk_buffer = torch.einsum(\"nm,mhd->nhd\", qk_grad.T, q.float())\n\n            if step == 0:\n                dq = block_dq_buffer\n                dk = block_dk_buffer\n            else:\n                dq += block_dq_buffer\n                d_k_comm.wait()\n                dk = block_dk_buffer + next_dk\n\n            if step + 1 != k_comm.world_size:\n                k_comm.wait()\n                k = next_k\n\n            next_dk = d_k_comm.send_recv(dk)\n            d_k_comm.commit()\n\n        d_k_comm.wait()\n\n        return dq, next_dk, None\n    \n\nclass InfProb(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, q, k, group):\n        rank = dist.get_rank()\n        k = k.contiguous()\n        comm = RingComm(group)\n\n        colle = [q, k]\n\n        lse = None\n        next_k = None\n        for step in range(comm.world_size):\n            if step + 1 != comm.world_size:\n                next_k: torch.Tensor = comm.send_recv(k)\n                comm.commit()\n\n            # flash lse\n            block_lse = _flash_prob_forward(q, k)\n\n            if step == 0:\n                lse = block_lse\n            else:\n                lse = lse - F.logsigmoid(lse - block_lse)\n\n            if step + 1 != comm.world_size:\n                comm.wait()\n                k = next_k\n\n        # this should be out_padded\n        colle.append(lse)\n        ctx.save_for_backward(*colle)\n        ctx.group = group\n        return lse\n\n    @staticmethod\n    def backward(ctx, dlse):\n        rank = dist.get_rank()\n        q, k, lse = ctx.saved_tensors\n        k_comm = RingComm(ctx.group)\n        d_k_comm = RingComm(ctx.group)\n        dq, dk = None, None\n        next_dk = None\n\n        block_dq_buffer = torch.empty(q.shape, dtype=torch.float32, device=q.device)\n        block_dk_buffer = torch.empty(k.shape, dtype=torch.float32, device=k.device)\n\n        next_dk, next_k = None, None\n\n        for step in range(k_comm.world_size):\n            if step + 1 != k_comm.world_size:\n                next_k = k_comm.send_recv(k)\n                k_comm.commit()\n\n            # flash gradient calculation\n            block_dq_buffer, block_dk_buffer = _flash_prob_backward(q, k, lse, dlse)\n\n            if step == 0:\n                dq = block_dq_buffer\n                dk = block_dk_buffer\n            else:\n                dq += block_dq_buffer\n                d_k_comm.wait()\n                dk = block_dk_buffer + next_dk\n\n            if step + 1 != k_comm.world_size:\n                k_comm.wait()\n                k = next_k\n\n            next_dk = d_k_comm.send_recv(dk)\n            d_k_comm.commit()\n\n        d_k_comm.wait()\n\n        return dq, next_dk, None\n\n\ndef set_seed(rank, seed=42):\n    seed = rank + seed\n    random.seed(seed)             \n    torch.manual_seed(seed)      \n    torch.cuda.manual_seed(seed)  \n    torch.cuda.manual_seed_all(seed) \n\n\ndef _cal_ring_loss(q, k, labels, 
head_dim=256):\n    bq = q.shape[0]\n    bk = k.shape[0]\n    q = q.view(bq, -1, head_dim).float()\n    k = k.view(bk, -1, head_dim).float()\n\n    lse = RingProb.apply(q, k, None)\n    numerator = torch.einsum(\"mhd,mhd->m\", q, k[labels, ...])\n    loss = -numerator + lse\n\n    return loss\n\n\ndef _cal_inf_loss(q, k, labels, head_dim=256):\n    bq = q.shape[0]\n    bk = k.shape[0]\n    q = q.view(bq, -1, head_dim).float()\n    k = k.view(bk, -1, head_dim).float()\n\n    lse = InfProb.apply(q, k, None)\n    numerator = torch.einsum(\"mhd,mhd->m\", q, k[labels, ...])\n    loss = -numerator + lse\n\n    return loss\n\n\ndef cal_ring_loss(q, k, labels=None, scale=None, head_dim=256):\n    \"\"\"The triton implementation of the ring-cl.\n\n    Args:\n        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].\n        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].\n        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None.\n        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.\n        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). Defaults to 256.\n\n    \"\"\"\n\n    if labels is None:\n        labels = torch.arange(q.shape[0]).to(q.device)\n    if scale is None:\n        scale = 1.0\n    else:\n        scale = GradientGather.apply(scale)\n    if torch.distributed.is_initialized():\n        return _cal_ring_loss(scale * q, k, labels, head_dim).mean()\n    else:\n        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()\n\n\ndef cal_inf_loss(q, k, labels=None, scale=None, head_dim=256):\n    \"\"\"The triton implementation of the inf-cl.\n\n    Args:\n        q (torch.Tensor): The column tensor in contrastive loss. The shape is [B, D].\n        k (torch.Tensor): The row tensor in contrastive loss. The shape is [B, D].\n        labels (torch.Tensor, optional): In CLIP loss, the labels are the indices of the positive pairs. The shape is [B]. When setting to None, the labels are the range of [0, B). Defaults to None.\n        scale (torch.Tensor, optional): The scale tensor of the query tensor. Defaults to None.\n        head_dim (int, optional): The head dimension. (must be 16, 32, 64, 128 or 256). 
Defaults to 256.\n\n    \"\"\"\n\n    if labels is None:\n        labels = torch.arange(q.shape[0]).to(q.device)\n    if scale is None:\n        scale = 1.0\n    else:\n        scale = GradientGather.apply(scale)\n    if torch.distributed.is_initialized():\n        return _cal_inf_loss(scale * q, k, labels, head_dim).mean()\n    else:\n        return _cal_flash_loss(scale * q, k, labels, head_dim).mean()\n\n\nif __name__ == \"__main__\":\n    import time\n\n    dist.init_process_group(\"nccl\")\n\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n\n    torch.cuda.set_device(f'cuda:{os.environ[\"LOCAL_RANK\"]}')\n\n    # Parameters\n    dtype = torch.float32\n    num_heads = 3        # Number of attention heads\n    seq_length_q = 32768 # Sequence length\n    seq_length_k = 32768\n    d_model = 256        # Dimension of each head (must be 16, 32, 64, or 128)\n\n    # Randomly initialize inputs\n    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda\")\n    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda\")\n    l = torch.ones([], dtype=dtype, device=\"cuda\") * np.log(1 / 0.07); l = l.requires_grad_() # Logit scale\n\n    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query\n    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key\n\n    q1 = q.clone().detach().requires_grad_()\n    k1 = k.clone().detach().requires_grad_()\n    l1 = l.clone().detach().requires_grad_()\n\n    for i in range(1000):\n        # A. local torch gradient\n        start = time.time()\n        # A.1. gather q, k\n        gathered_q = [torch.zeros_like(q) for _ in range(world_size)]\n        gathered_k = [torch.zeros_like(k) for _ in range(world_size)]\n        dist.all_gather(gathered_q, q)\n        dist.all_gather(gathered_k, k)\n        gathered_q[rank] = q\n        gathered_k[rank] = k\n        all_q = torch.cat(gathered_q, dim=0)\n        all_k = torch.cat(gathered_k, dim=0)\n        # A.2. calculating qk logits\n        qk = torch.einsum(\"md,nd->mn\", l.exp() * all_q, all_k)\n        kq = qk.T\n        _labels = torch.arange(seq_length_q).to(q.device)\n        # A.3. calculating loss\n        loss_i2t = F.cross_entropy(qk, _labels, reduction=\"mean\")\n        loss_t2i = F.cross_entropy(kq, _labels, reduction=\"mean\")\n        # A.4. scaling loss to normal value\n        scale_factor = (all_q.shape[0] / q.shape[0])\n        loss = (loss_i2t + loss_t2i) * 0.5 * scale_factor\n        loss.backward()\n        show_loss = loss.detach().clone()\n        dist.all_reduce(show_loss)\n        show_loss = show_loss / (world_size * scale_factor)\n        end = time.time()\n\n        dist.barrier()\n\n        # B. triton implementation\n        start1 = time.time()\n        # labels = torch.arange(seq_length_q // world_size).to(q.device)\n        loss1_i2t = cal_inf_loss(q1, k1, scale=l1.exp())\n        loss1_t2i = cal_inf_loss(k1, q1, scale=l1.exp())\n        loss1 = (loss1_i2t + loss1_t2i).mean() * 0.5\n        loss1.backward()\n        end1 = time.time()\n\n        dist.barrier()\n\n        if rank == 0:\n            print(rank, end - start, end1 - start1, loss, show_loss, loss1)\n            print(l.grad, l1.grad, torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)))\n\n        q.grad = None; k.grad = None; l.grad = None\n        q1.grad = None; k1.grad = None; l1.grad = None\n"
  },
  {
    "path": "inf_clip/__init__.py",
    "content": "from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\n\nfrom .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss\nfrom .factory import list_models, add_model_config, get_model_config, load_checkpoint\n\nfrom .openai import load_openai_model, list_openai_models\nfrom .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \\\n    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained\n\nfrom .models.tokenizer import SimpleTokenizer, tokenize, decode\nfrom .models.transform import image_transform, AugmentationCfg\nfrom .models.coca_arch import CoCa\nfrom .models.clip_arch import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \\\n    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \\\n    get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg\nfrom .models.lit_arch import LiT\nfrom .models.loss import ClipLoss, DistillClipLoss, CoCaLoss\n\nfrom .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy\nfrom .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES\n"
  },
  {
    "path": "inf_clip/constants.py",
    "content": "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\nOPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\nIMAGENET_MEAN = (0.485, 0.456, 0.406)\nIMAGENET_STD = (0.229, 0.224, 0.225)\nINCEPTION_MEAN = (0.5, 0.5, 0.5)\nINCEPTION_STD = (0.5, 0.5, 0.5)\n"
  },
  {
    "path": "inf_clip/factory.py",
    "content": "import json\nimport logging\nimport os\nimport re\nfrom copy import deepcopy\nfrom dataclasses import asdict\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\n\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\n\nfrom .openai import load_openai_model\nfrom .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\\\n    list_pretrained_tags_by_model, download_pretrained_from_hf, convert_state_dict\n\nfrom .models.tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH\nfrom .models.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs\nfrom .models.clip_arch import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\\\n    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg\nfrom .models.coca_arch import CoCa\nfrom .models.lit_arch import LiT\nfrom .models.loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss, FlashClipLoss, RingClipLoss, InfClipLoss, DiscoClipLoss\n\n\nHF_HUB_PREFIX = 'hf-hub:'\n_MODEL_CONFIG_PATHS = [Path(__file__).parent / f\"model_configs/\"]\n_MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs\n\n\ndef _natural_key(string_):\n    return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\n\n\ndef _rescan_model_configs():\n    global _MODEL_CONFIGS\n\n    config_ext = ('.json',)\n    config_files = []\n    for config_path in _MODEL_CONFIG_PATHS:\n        if config_path.is_file() and config_path.suffix in config_ext:\n            config_files.append(config_path)\n        elif config_path.is_dir():\n            for ext in config_ext:\n                config_files.extend(config_path.glob(f'*{ext}'))\n\n    for cf in config_files:\n        with open(cf, 'r') as f:\n            model_cfg = json.load(f)\n            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):\n                _MODEL_CONFIGS[cf.stem] = model_cfg\n\n    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}\n\n\n_rescan_model_configs()  # initial populate of model config registry\n\n\ndef list_models():\n    \"\"\" enumerate available model architectures based on config files \"\"\"\n    return list(_MODEL_CONFIGS.keys())\n\n\ndef add_model_config(path):\n    \"\"\" add model config path or file and update registry \"\"\"\n    if not isinstance(path, Path):\n        path = Path(path)\n    _MODEL_CONFIG_PATHS.append(path)\n    _rescan_model_configs()\n\n\ndef get_model_config(model_name):\n    if model_name in _MODEL_CONFIGS:\n        return deepcopy(_MODEL_CONFIGS[model_name])\n    else:\n        return None\n\n\ndef _get_hf_config(model_id, cache_dir=None):\n    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)\n    with open(config_path, 'r', encoding='utf-8') as f:\n        config = json.load(f)\n    return config\n\n\ndef get_tokenizer(\n        model_name: str = '',\n        context_length: Optional[int] = None,\n        **kwargs,\n):\n    if model_name.startswith(HF_HUB_PREFIX):\n        model_name = model_name[len(HF_HUB_PREFIX):]\n        try:\n            config = _get_hf_config(model_name)['model_cfg']\n        except Exception:\n            tokenizer = HFTokenizer(\n                model_name,\n                context_length=context_length or DEFAULT_CONTEXT_LENGTH,\n              
  **kwargs,\n            )\n            return tokenizer\n    else:\n        config = get_model_config(model_name)\n        assert config is not None, f\"No valid model config found for {model_name}.\"\n\n    text_config = config.get('text_cfg', {})\n    if 'tokenizer_kwargs' in text_config:\n        tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs)\n    else:\n        tokenizer_kwargs = kwargs\n\n    if context_length is None:\n        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)\n\n    if 'hf_tokenizer_name' in text_config:\n        tokenizer = HFTokenizer(\n            text_config['hf_tokenizer_name'],\n            context_length=context_length,\n            **tokenizer_kwargs,\n        )\n    else:\n        tokenizer = SimpleTokenizer(\n            context_length=context_length,\n            **tokenizer_kwargs,\n        )\n\n    return tokenizer\n\n\ndef load_state_dict(checkpoint_path: str, map_location='cpu'):\n    checkpoint = torch.load(checkpoint_path, map_location=map_location)\n    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n        state_dict = checkpoint['state_dict']\n    elif isinstance(checkpoint, torch.jit.ScriptModule):\n        state_dict = checkpoint.state_dict()\n        for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n            state_dict.pop(key, None)\n    else:\n        state_dict = checkpoint\n    if next(iter(state_dict.items()))[0].startswith('module'):\n        state_dict = {k[7:]: v for k, v in state_dict.items()}\n    return state_dict\n\n\ndef load_checkpoint(\n        model: Union[CLIP, CustomTextCLIP],\n        checkpoint_path: str,\n        strict: bool = True,\n):\n    if Path(checkpoint_path).suffix in ('.npz', '.npy'):\n        # Separate path loading numpy big_vision (SigLIP) weights\n        # NOTE: use this package's own pretrained module rather than the upstream open_clip absolute import\n        from .pretrained import load_big_vision_weights\n        load_big_vision_weights(model, checkpoint_path)\n        return {}\n\n    state_dict = load_state_dict(checkpoint_path)\n\n    # Detect & convert 3rd party state_dicts -> open_clip\n    state_dict = convert_state_dict(model, state_dict)\n\n    # Detect old format and make compatible with new format\n    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):\n        state_dict = convert_to_custom_text_state_dict(state_dict)\n\n    # If loading a non-SigLIP model for SigLIP training. 
See https://github.com/mlfoundations/open_clip/issues/712\n    if 'logit_bias' not in state_dict and model.logit_bias is not None:\n        state_dict[\"logit_bias\"] = torch.zeros_like(state_dict[\"logit_scale\"])\n\n    # Certain text transformers no longer expect position_ids after transformers==4.31\n    position_id_key = 'text.transformer.embeddings.position_ids'\n    if position_id_key in state_dict and not hasattr(model, position_id_key):\n        del state_dict[position_id_key]\n\n    resize_pos_embed(state_dict, model)\n    resize_text_pos_embed(state_dict, model)\n\n    # Finally, load the massaged state_dict into model\n    incompatible_keys = model.load_state_dict(state_dict, strict=strict)\n    return incompatible_keys\n\n\ndef create_model(\n        model_name: str,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n        force_custom_text: bool = False,\n        force_patch_dropout: Optional[float] = None,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        force_preprocess_cfg: Optional[Dict[str, Any]] = None,\n        pretrained_image: bool = False,\n        pretrained_hf: bool = False,\n        cache_dir: Optional[str] = None,\n        output_dict: Optional[bool] = None,\n        require_pretrained: bool = False,\n        **model_kwargs,\n):\n    force_preprocess_cfg = force_preprocess_cfg or {}\n    preprocess_cfg = asdict(PreprocessCfg())\n    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)\n    if has_hf_hub_prefix:\n        model_id = model_name[len(HF_HUB_PREFIX):]\n        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n        config = _get_hf_config(model_id, cache_dir)\n        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])\n        model_cfg = config['model_cfg']\n        pretrained_hf = False  # override, no need to load original HF text weights\n    else:\n        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names\n        checkpoint_path = None\n        model_cfg = None\n\n    if isinstance(device, str):\n        device = torch.device(device)\n\n    if pretrained and pretrained.lower() == 'openai':\n        logging.info(f'Loading pretrained {model_name} from OpenAI.')\n        model = load_openai_model(\n            model_name,\n            precision=precision,\n            device=device,\n            cache_dir=cache_dir,\n        )\n    else:\n        model_cfg = model_cfg or get_model_config(model_name)\n        if model_cfg is not None:\n            logging.info(f'Loaded {model_name} model config.')\n        else:\n            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')\n            raise RuntimeError(f'Model config for {model_name} not found.')\n\n        if force_quick_gelu:\n            # override for use of QuickGELU on non-OpenAI transformer models\n            model_cfg[\"quick_gelu\"] = True\n\n        if force_patch_dropout is not None and force_patch_dropout != False:\n            # override the default patch dropout value\n            model_cfg[\"vision_cfg\"][\"patch_dropout\"] = force_patch_dropout\n\n        if force_image_size is not None and force_image_size != False:\n            # override model config's image size\n            model_cfg[\"vision_cfg\"][\"image_size\"] = force_image_size\n\n        
is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})\n        if pretrained_image:\n            if is_timm_model:\n                # pretrained weight loading for timm models set via vision_cfg\n                model_cfg['vision_cfg']['timm_model_pretrained'] = True\n            else:\n                assert False, 'pretrained image towers currently only supported for timm models'\n\n        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes\n        cast_dtype = get_cast_dtype(precision)\n        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})\n        # NOTE: unlike upstream open_clip, the pretrained_hf argument is not applied here; whether\n        # the HF text tower loads its original pretrained weights is controlled solely by the\n        # model config (text_cfg.hf_model_pretrained).\n        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model\n\n        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)\n        model_arch = model_cfg.pop(\"arch\", \"CLIP\")\n        if custom_text:\n            if \"CLIP\" in model_arch:\n                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)\n            elif \"LiT\" in model_arch:\n                model = LiT(**model_cfg, cast_dtype=cast_dtype)\n            elif \"CoCa\" in model_arch or \"multimodal_cfg\" in model_cfg:\n                model = CoCa(**model_cfg, cast_dtype=cast_dtype)\n            else:\n                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)\n        else:\n            model = CLIP(**model_cfg, cast_dtype=cast_dtype)\n\n        if precision in (\"fp16\", \"bf16\"):\n            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n            # manual mixed precision that matches original OpenAI behaviour\n            if is_timm_model:\n                # FIXME this is a bit janky, create timm based model in low-precision and\n                # then cast only LayerNormFp32 instances back to float32 so they don't break.\n                # Why? 
The convert_weights_to_lp fn only works with native models.\n                model.to(device=device, dtype=dtype)\n                from .transformer import LayerNormFp32\n\n                def _convert_ln(m):\n                    if isinstance(m, LayerNormFp32):\n                        m.weight.data = m.weight.data.to(torch.float32)\n                        m.bias.data = m.bias.data.to(torch.float32)\n                model.apply(_convert_ln)\n            else:\n                model.to(device=device)\n                convert_weights_to_lp(model, dtype=dtype)\n        elif precision in (\"pure_fp16\", \"pure_bf16\"):\n            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n            model.to(device=device, dtype=dtype)\n        else:\n            model.to(device=device)\n\n        pretrained_loaded = False\n        if pretrained:\n            checkpoint_path = ''\n            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)\n            if pretrained_cfg:\n                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)\n                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)\n            elif os.path.exists(pretrained):\n                checkpoint_path = pretrained\n\n            if checkpoint_path:\n                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\n                load_checkpoint(model, checkpoint_path)\n            else:\n                error_str = (\n                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'\n                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')\n                logging.warning(error_str)\n                raise RuntimeError(error_str)\n            pretrained_loaded = True\n        elif has_hf_hub_prefix:\n            logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')\n            load_checkpoint(model, checkpoint_path)\n            pretrained_loaded = True\n\n        if require_pretrained and not pretrained_loaded:\n            # callers of create_model_from_pretrained always expect pretrained weights\n            raise RuntimeError(\n                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')\n\n    if output_dict and hasattr(model, \"output_dict\"):\n        model.output_dict = True\n\n    if jit:\n        model = torch.jit.script(model)\n\n    # set image preprocessing configuration in model attributes for convenience\n    if getattr(model.visual, 'image_size', None) is not None:\n        # use image_size set on model creation (via config or force_image_size arg)\n        force_preprocess_cfg['size'] = model.visual.image_size\n    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))\n\n    return model\n\n\ndef create_loss(args):\n    if args.distill_model:\n        return DistillClipLoss(\n            local_loss=args.local_loss,\n            gather_with_grad=args.gather_with_grad,\n            cache_labels=True,\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod,\n        )\n    elif \"coca\" in args.model.lower():\n        return CoCaLoss(\n            caption_loss_weight=args.coca_caption_loss_weight,\n            clip_loss_weight=args.coca_contrastive_loss_weight,\n            local_loss=args.local_loss,\n            gather_with_grad=args.gather_with_grad,\n            
cache_labels=True,\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod,\n        )\n    elif args.siglip:\n        assert not args.horovod, \"Horovod not currently supported for SigLip\"\n        return SigLipLoss(\n            rank=args.rank,\n            world_size=args.world_size,\n        )\n    elif args.flashloss:\n        return FlashClipLoss(\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod,\n        )\n    elif args.ringloss:\n        return RingClipLoss(\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod\n        )\n    elif args.infloss:\n        return InfClipLoss(\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod\n        )\n    elif args.discoloss:\n        return DiscoClipLoss(\n            rank=args.rank,\n            world_size=args.world_size,\n            use_horovod=args.horovod\n        )\n    return ClipLoss(\n        local_loss=args.local_loss,\n        gather_with_grad=args.gather_with_grad,\n        cache_labels=True,\n        rank=args.rank,\n        world_size=args.world_size,\n        use_horovod=args.horovod,\n    )\n\n\ndef create_model_and_transforms(\n        model_name: str,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n        force_custom_text: bool = False,\n        force_patch_dropout: Optional[float] = None,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        image_mean: Optional[Tuple[float, ...]] = None,\n        image_std: Optional[Tuple[float, ...]] = None,\n        image_interpolation: Optional[str] = None,\n        image_resize_mode: Optional[str] = None,  # only effective for inference\n        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n        pretrained_image: bool = False,\n        pretrained_hf: bool = False,\n        cache_dir: Optional[str] = None,\n        output_dict: Optional[bool] = None,\n        **model_kwargs,\n):\n    force_preprocess_cfg = merge_preprocess_kwargs(\n        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)\n\n    model = create_model(\n        model_name,\n        pretrained,\n        precision=precision,\n        device=device,\n        jit=jit,\n        force_quick_gelu=force_quick_gelu,\n        force_custom_text=force_custom_text,\n        force_patch_dropout=force_patch_dropout,\n        force_image_size=force_image_size,\n        force_preprocess_cfg=force_preprocess_cfg,\n        pretrained_image=pretrained_image,\n        pretrained_hf=pretrained_hf,\n        cache_dir=cache_dir,\n        output_dict=output_dict,\n        **model_kwargs,\n    )\n\n    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)\n\n    preprocess_train = image_transform_v2(\n        pp_cfg,\n        is_train=True,\n        aug_cfg=aug_cfg,\n    )\n    preprocess_val = image_transform_v2(\n        pp_cfg,\n        is_train=False,\n    )\n\n    return model, preprocess_train, preprocess_val\n\n\ndef create_model_from_pretrained(\n        model_name: str,\n        pretrained: Optional[str] = None,\n        precision: str = 'fp32',\n        device: Union[str, torch.device] = 'cpu',\n        jit: bool = False,\n        force_quick_gelu: bool = False,\n       
 force_custom_text: bool = False,\n        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n        image_mean: Optional[Tuple[float, ...]] = None,\n        image_std: Optional[Tuple[float, ...]] = None,\n        image_interpolation: Optional[str] = None,\n        image_resize_mode: Optional[str] = None,  # only effective for inference\n        return_transform: bool = True,\n        cache_dir: Optional[str] = None,\n        **model_kwargs,\n):\n    force_preprocess_cfg = merge_preprocess_kwargs(\n        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)\n\n    model = create_model(\n        model_name,\n        pretrained,\n        precision=precision,\n        device=device,\n        jit=jit,\n        force_quick_gelu=force_quick_gelu,\n        force_custom_text=force_custom_text,\n        force_image_size=force_image_size,\n        force_preprocess_cfg=force_preprocess_cfg,\n        cache_dir=cache_dir,\n        require_pretrained=True,\n        **model_kwargs,\n    )\n\n    if not return_transform:\n        return model\n\n    preprocess = image_transform_v2(\n        PreprocessCfg(**model.visual.preprocess_cfg),\n        is_train=False,\n    )\n\n    return model, preprocess\n"
  },
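  {
    "path": "examples/factory_usage_example.py",
    "content": "# NOTE: illustrative usage sketch added for documentation; it is not part of the original\n# repository layout. It assumes the factory helpers defined in this package\n# (list_models, create_model_and_transforms, get_tokenizer, create_loss) are re-exported\n# from the package root, as in open_clip; if not, import them from the factory module instead.\nfrom types import SimpleNamespace\n\nimport torch\nfrom PIL import Image\n\nfrom inf_clip import create_loss, create_model_and_transforms, get_tokenizer, list_models\n\n# Architectures are discovered from the JSON files under model_configs/.\nprint(list_models())\n\n# Build a ViT-B-32 CLIP model from its config; pretrained=None keeps random weights,\n# or pass a pretrained tag / local checkpoint path to load weights.\nmodel, preprocess_train, preprocess_val = create_model_and_transforms(\n    'ViT-B-32',\n    pretrained=None,\n    precision='fp32',\n    device='cpu',\n)\ntokenizer = get_tokenizer('ViT-B-32')\n\nimage = preprocess_val(Image.new('RGB', (224, 224))).unsqueeze(0)\ntext = tokenizer(['a diagram', 'a dog', 'a cat'])\n\nwith torch.no_grad():\n    # encode_image / encode_text follow the open_clip CLIP interface used by this fork\n    image_features = model.encode_image(image)\n    text_features = model.encode_text(text)\n    image_features = image_features / image_features.norm(dim=-1, keepdim=True)\n    text_features = text_features / text_features.norm(dim=-1, keepdim=True)\n    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)\nprint(probs)\n\n# create_loss expects an argparse-style namespace; the fields below mirror the attributes it\n# reads. With every selector flag left False it returns the standard ClipLoss; setting e.g.\n# infloss=True (or siglip, flashloss, ringloss, discoloss) selects the corresponding loss.\nargs = SimpleNamespace(\n    model='ViT-B-32', distill_model=None, siglip=False, flashloss=False, ringloss=False,\n    infloss=False, discoloss=False, local_loss=False, gather_with_grad=False,\n    rank=0, world_size=1, horovod=False,\n)\nloss_fn = create_loss(args)\nprint(type(loss_fn).__name__)  # ClipLoss\n"
  },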
  {
    "path": "inf_clip/model_configs/EVA01-g-14-plus.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva_giant_patch14_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA01-g-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva_giant_patch14_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA02-B-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_base_patch16_clip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA02-E-14-plus.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_enormous_patch14_clip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 32\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA02-E-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_enormous_patch14_clip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA02-L-14-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"timm_model_name\": \"eva02_large_patch14_clip_336\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/EVA02-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"eva02_large_patch14_clip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/LiT-B-16.json",
    "content": "{\n    \"arch\": \"LiT-B-16\",\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"vit_base_patch16_224\",\n        \"timm_model_pretrained\": true,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": \"linear\"\n    },\n    \"text_cfg\": {\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"hf_model_name\": \"bert-base-uncased\",\n        \"hf_model_pretrained\": true,\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/LiT-B-32.json",
    "content": "{\n    \"arch\": \"LiT-B-32\",\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"vit_base_patch32_224\",\n        \"timm_model_pretrained\": true,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": \"linear\"\n    },\n    \"text_cfg\": {\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"hf_model_name\": \"bert-base-uncased\",\n        \"hf_model_pretrained\": true,\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/LiT-L-16.json",
    "content": "{\n    \"arch\": \"LiT-L-16\",\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"vit_large_patch16_224\",\n        \"timm_model_pretrained\": true,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": \"linear\"\n    },\n    \"text_cfg\": {\n        \"hf_tokenizer_name\": \"bert-large-uncased\",\n        \"hf_model_name\": \"bert-large-uncased\",\n        \"hf_model_pretrained\": true,\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-B.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_base_mci_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"token\",\n        \"timm_proj\": null,\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.0,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"no_causal_mask\": false\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-S1.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"fastvit_mci1\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"avg\",\n        \"timm_proj\": null,\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.0,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"no_causal_mask\": true\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/MobileCLIP-S2.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"fastvit_mci2\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"avg\",\n        \"timm_proj\": null,\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.0,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"no_causal_mask\": true\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/RN101-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/RN101.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            23,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/RN50-quickgelu.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/RN50.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": [\n            3,\n            4,\n            6,\n            3\n        ],\n        \"width\": 64,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/RN50x16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"layers\": [\n            6,\n            8,\n            18,\n            8\n        ],\n        \"width\": 96,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/RN50x4.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": [\n            4,\n            6,\n            10,\n            6\n        ],\n        \"width\": 80,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/RN50x64.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 448,\n        \"layers\": [\n            3,\n            15,\n            36,\n            10\n        ],\n        \"width\": 128,\n        \"patch_size\": null\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-256.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"timm_model_name\": \"vit_base_patch16_siglip_256\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-384.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"timm_model_name\": \"vit_base_patch16_siglip_384\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-512.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 512,\n        \"timm_model_name\": \"vit_base_patch16_siglip_512\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP-i18n-256.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"timm_model_name\": \"vit_base_patch16_siglip_256\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 250000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP-i18n-256\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-SigLIP.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"vit_base_patch16_siglip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-plus-240.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 240,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-plus.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-256.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-plus-256.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"layers\": 12,\n        \"width\": 896,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32-quickgelu.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-378-quickgelu.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 378,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-CLIPA-336.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-CLIPA.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14-quickgelu.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-H-16.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-280.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 280,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-CLIPA-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-CLIPA.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14-quickgelu.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-320.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 320,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-SigLIP-256.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 256,\n        \"timm_model_name\": \"vit_large_patch16_siglip_256\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16-SigLIP-384.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"timm_model_name\": \"vit_large_patch16_siglip_384\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-L-16.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-M-16-alt.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 16,\n        \"ls_init_value\": 1e-4\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-M-16.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-M-32-alt.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-M-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-S-16-alt.json",
    "content": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 256,\n        \"heads\": 4,\n        \"layers\": 10\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-S-16.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 16\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-S-32-alt.json",
    "content": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 256,\n        \"heads\": 4,\n        \"layers\": 10\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-S-32.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 384,\n        \"heads\": 6,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP-384.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"timm_model_name\": \"vit_so400m_patch14_siglip_384\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 64,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 1152,\n        \"heads\": 16,\n        \"layers\": 27,\n        \"mlp_ratio\": 3.7362,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-SO400M-14-SigLIP.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"init_logit_bias\": -10,\n    \"custom_text\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"timm_model_name\": \"vit_so400m_patch14_siglip_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"context_length\": 16,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"timm/ViT-B-16-SigLIP\",\n        \"tokenizer_kwargs\": {\n            \"clean\": \"canonicalize\"\n        },\n        \"width\": 1152,\n        \"heads\": 16,\n        \"layers\": 27,\n        \"mlp_ratio\": 3.7362,\n        \"no_causal_mask\": true,\n        \"proj_bias\": true,\n        \"pool_type\": \"last\",\n        \"norm_kwargs\":{\n            \"eps\": 1e-6\n        }\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14-CLIPA-336.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 336,\n        \"layers\": 48,\n        \"width\": 1664,\n        \"head_width\": 104,\n        \"mlp_ratio\": 4.9231,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 32,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14-CLIPA.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n        \"head_width\": 104,\n        \"mlp_ratio\": 4.9231,\n        \"patch_size\": 14,\n        \"no_ln_pre\": true,\n        \"pool_type\": \"avg\",\n        \"final_ln_after_pool\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 32,\n        \"vocab_size\": 32000,\n        \"hf_tokenizer_name\": \"bert-base-uncased\",\n        \"tokenizer_kwargs\": {\n            \"strip_sep_token\": true\n        },\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 32,\n        \"pool_type\": \"last\",\n        \"no_causal_mask\": true\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-bigG-14.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n        \"head_width\": 104,\n        \"mlp_ratio\": 4.9231,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 32\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-e-14.json",
    "content": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 56,\n        \"width\": 1792,\n        \"head_width\": 112,\n        \"mlp_ratio\": 8.5715,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1280,\n        \"heads\": 20,\n        \"layers\": 36\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViT-g-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 40,\n        \"width\": 1408,\n        \"head_width\": 88,\n        \"mlp_ratio\": 4.3637,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-B-LTT.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_base_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 768,\n      \"heads\": 12,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-B.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_base_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 512,\n      \"heads\": 8,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L-256.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_256\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 256\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 768,\n      \"heads\": 12,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L-336.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_336\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 336\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 768,\n      \"heads\": 12,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 768,\n      \"heads\": 12,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2-256.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_256\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 256\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1024,\n      \"heads\": 16,\n      \"layers\": 24\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2-336.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_336\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 336\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1024,\n      \"heads\": 16,\n      \"layers\": 24\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-L2.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_large2_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1024,\n      \"heads\": 16,\n      \"layers\": 24\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-S-LTT.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_small_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 768,\n      \"heads\": 12,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-S.json",
    "content": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_small_224\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 224\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 384,\n      \"heads\": 6,\n      \"layers\": 12\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-256.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_256\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 256\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1152,\n      \"heads\": 16,\n      \"layers\": 27\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-336.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_336\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 336\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1152,\n      \"heads\": 16,\n      \"layers\": 27\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/ViTamin-XL-384.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"vision_cfg\": {\n      \"timm_model_name\": \"vitamin_xlarge_384\",\n      \"timm_model_pretrained\": false,\n      \"timm_pool\": \"\",\n      \"timm_proj\": \"linear\",\n      \"timm_drop\": 0.0,\n      \"timm_drop_path\": 0.1,\n      \"image_size\": 256\n    },\n    \"text_cfg\": {\n      \"context_length\": 77,\n      \"vocab_size\": 49408,\n      \"width\": 1152,\n      \"heads\": 16,\n      \"layers\": 27\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/coca_ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32,\n        \"attentional_pool\": true,\n        \"attn_pooler_heads\": 8,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12,\n        \"attn_pooler_heads\": 8\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/coca_ViT-L-14.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n        \"patch_size\": 14,\n        \"attentional_pool\": true,\n        \"attn_pooler_heads\": 8,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12,\n        \"attn_pooler_heads\": 12\n    },\n    \"custom_text\": true\n}\n"
  },
  {
    "path": "inf_clip/model_configs/coca_base.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"multimodal_cfg\": {\n        \"width\": 768,\n        \"context_length\": 76,\n        \"vocab_size\": 64000,\n        \"mlp_ratio\": 4,\n        \"layers\": 12,\n        \"dim_head\": 64,\n        \"heads\": 12,\n        \"n_queries\": 256,\n        \"attn_pooler_heads\": 8\n    },\n    \"vision_cfg\": {\n        \"image_size\": 288,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 18,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"context_length\": 76,\n        \"vocab_size\": 64000,\n        \"layers\": 12,\n        \"heads\": 12,\n        \"width\": 768,\n        \"embed_cls\": true,\n        \"output_tokens\": true\n    },\n    \"custom_text\": true\n}"
  },
  {
    "path": "inf_clip/model_configs/coca_roberta-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32,\n        \"output_tokens\": true\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"roberta-base\",\n        \"hf_tokenizer_name\": \"roberta-base\",\n        \"hf_proj_type\": \"linear\",\n        \"width\": 768,\n        \"output_tokens\": true\n    },\n    \"multimodal_cfg\": {\n        \"context_length\": 76,\n        \"width\": 768,\n        \"heads\": 8,\n        \"layers\": 12\n    },\n    \"custom_text\": true\n}\n"
  },
  {
    "path": "inf_clip/model_configs/convnext_base.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_base_w.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_base_w_320.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_large.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_large_d.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"mlp\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 16\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_large_d_320.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"mlp\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 768,\n        \"heads\": 12,\n        \"layers\": 16\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_small.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_small\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_tiny.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_tiny\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_xlarge.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 20\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_xxlarge.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/convnext_xxlarge_320.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"timm_drop\": 0.0,\n        \"timm_drop_path\": 0.1,\n        \"image_size\": 320\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 1024,\n        \"heads\": 16,\n        \"layers\": 24\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/mt5-base-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"google/mt5-base\",\n        \"hf_tokenizer_name\": \"google/mt5-base\",\n        \"hf_pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/mt5-xl-ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"google/mt5-xl\",\n        \"hf_tokenizer_name\": \"google/mt5-xl\",\n        \"hf_pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-base-siglip.json",
    "content": "{\n    \"embed_dim\": 768,\n    \"custom_text\": true,\n    \"init_logit_bias\": -10,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"timm_model_name\": \"vit_base_patch16_siglip_384\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"facebook/nllb-200-distilled-600M\",\n        \"hf_tokenizer_name\": \"facebook/nllb-200-distilled-600M\",\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-base.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"facebook/nllb-200-distilled-600M\",\n        \"hf_tokenizer_name\": \"facebook/nllb-200-distilled-600M\",\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-large-siglip.json",
    "content": "{\n    \"embed_dim\": 1152,\n    \"custom_text\": true,\n    \"init_logit_bias\": -10,\n    \"vision_cfg\": {\n        \"image_size\": 384,\n        \"timm_model_name\": \"vit_so400m_patch14_siglip_384\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"map\",\n        \"timm_proj\": \"none\"\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"facebook/nllb-200-distilled-1.3B\",\n        \"hf_tokenizer_name\": \"facebook/nllb-200-distilled-1.3B\",\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/nllb-clip-large.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"facebook/nllb-200-distilled-1.3B\",\n        \"hf_tokenizer_name\": \"facebook/nllb-200-distilled-1.3B\",\n        \"hf_proj_type\": \"linear\",\n        \"hf_pooler_type\": \"cls_pooler\"\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/roberta-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"roberta-base\",\n        \"hf_tokenizer_name\": \"roberta-base\",\n        \"hf_pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/swin_base_patch4_window7_224.json",
    "content": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"swin_base_patch4_window7_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 640,\n        \"heads\": 10,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/vit_medium_patch16_gap_256.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_medium_patch16_gap_256\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 256\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/vit_relpos_medium_patch16_cls_224.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_relpos_medium_patch16_cls_224\",\n        \"timm_model_pretrained\": false,\n        \"timm_pool\": \"\",\n        \"timm_proj\": \"linear\",\n        \"image_size\": 224\n    },\n    \"text_cfg\": {\n        \"context_length\": 77,\n        \"vocab_size\": 49408,\n        \"width\": 512,\n        \"heads\": 8,\n        \"layers\": 12\n    }\n}"
  },
  {
    "path": "inf_clip/model_configs/xlm-roberta-base-ViT-B-32.json",
    "content": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n        \"patch_size\": 32\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"xlm-roberta-base\",\n        \"hf_tokenizer_name\": \"xlm-roberta-base\",\n        \"hf_pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "inf_clip/model_configs/xlm-roberta-large-ViT-H-14.json",
    "content": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n        \"head_width\": 80,\n        \"patch_size\": 14\n    },\n    \"text_cfg\": {\n        \"hf_model_name\": \"xlm-roberta-large\",\n        \"hf_tokenizer_name\": \"xlm-roberta-large\",\n        \"hf_pooler_type\": \"mean_pooler\"\n    }\n}\n"
  },
  {
    "path": "inf_clip/models/clip_arch.py",
    "content": "\"\"\" CLIP Model\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimport copy\nimport logging\nimport math\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.utils.checkpoint import checkpoint\nfrom functools import partial\n\nfrom .hf_model import HFTextEncoder\nfrom .modified_resnet import ModifiedResNet\nfrom .timm_model import TimmModel\nfrom .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\\\n    text_global_pool\nfrom ..utils import to_2tuple\n\n\n@dataclass\nclass CLIPVisionCfg:\n    layers: Union[Tuple[int, int, int, int], int] = 12\n    width: int = 768\n    head_width: int = 64\n    mlp_ratio: float = 4.0\n    patch_size: int = 16\n    image_size: Union[Tuple[int, int], int] = 224\n\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\n    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)\n    attn_pooler_queries: int = 256  # n_queries for attentional pooler\n    attn_pooler_heads: int = 8  # n heads for attentional_pooling\n    no_ln_pre: bool = False  # disable pre transformer LayerNorm\n    pos_embed_type: str = 'learnable'\n    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling\n    pool_type: str = 'tok'\n    output_tokens: bool = False\n    act_kwargs: Optional[dict] = None\n    norm_kwargs: Optional[dict] = None\n\n    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size\n    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model\n    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')\n    timm_proj_bias: bool = False  # enable bias final projection\n    timm_drop: float = 0.  
# head dropout\n    timm_drop_path: Optional[float] = None  # backbone stochastic depth\n\n\n@dataclass\nclass CLIPTextCfg:\n    context_length: int = 77\n    vocab_size: int = 49408\n    hf_tokenizer_name: Optional[str] = None\n    tokenizer_kwargs: Optional[dict] = None\n\n    width: int = 512\n    heads: int = 8\n    layers: int = 12\n    mlp_ratio: float = 4.0\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    embed_cls: bool = False\n    pad_id: int = 0\n    no_causal_mask: bool = False  # disable causal masking\n    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling\n    pool_type: str = 'argmax'\n    proj_bias: bool = False\n    output_tokens: bool = False\n    act_kwargs: dict = None\n    norm_kwargs: dict = None\n\n    # HuggingFace specific text tower config\n    hf_model_name: Optional[str] = None\n    hf_model_pretrained: bool = True\n    hf_proj_type: str = 'mlp'\n    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models\n\n\ndef get_cast_dtype(precision: str):\n    cast_dtype = None\n    if precision == 'bf16':\n        cast_dtype = torch.bfloat16\n    elif precision == 'fp16':\n        cast_dtype = torch.float16\n    return cast_dtype\n\n\ndef get_input_dtype(precision: str):\n    input_dtype = None\n    if precision in ('bf16', 'pure_bf16'):\n        input_dtype = torch.bfloat16\n    elif precision in ('fp16', 'pure_fp16'):\n        input_dtype = torch.float16\n    return input_dtype\n\n\ndef _build_vision_tower(\n        embed_dim: int,\n        vision_cfg: CLIPVisionCfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None\n):\n    if isinstance(vision_cfg, dict):\n        vision_cfg = CLIPVisionCfg(**vision_cfg)\n\n    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\n    # memory efficient in recent PyTorch releases (>= 1.10).\n    # NOTE: timm models always use native GELU regardless of quick_gelu flag.\n    act_layer = QuickGELU if quick_gelu else nn.GELU\n\n    if vision_cfg.timm_model_name:\n        visual = TimmModel(\n            vision_cfg.timm_model_name,\n            pretrained=vision_cfg.timm_model_pretrained,\n            pool=vision_cfg.timm_pool,\n            proj=vision_cfg.timm_proj,\n            proj_bias=vision_cfg.timm_proj_bias,\n            drop=vision_cfg.timm_drop,\n            drop_path=vision_cfg.timm_drop_path,\n            patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,\n            embed_dim=embed_dim,\n            image_size=vision_cfg.image_size,\n        )\n    elif isinstance(vision_cfg.layers, (tuple, list)):\n        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width\n        visual = ModifiedResNet(\n            layers=vision_cfg.layers,\n            output_dim=embed_dim,\n            heads=vision_heads,\n            image_size=vision_cfg.image_size,\n            width=vision_cfg.width,\n        )\n    else:\n        vision_heads = vision_cfg.width // vision_cfg.head_width\n        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n        if vision_cfg.norm_kwargs:\n            norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)\n        if vision_cfg.act_kwargs is not None:\n            act_layer = partial(act_layer, **vision_cfg.act_kwargs)\n\n        visual = VisionTransformer(\n            image_size=vision_cfg.image_size,\n            patch_size=vision_cfg.patch_size,\n            width=vision_cfg.width,\n   
         layers=vision_cfg.layers,\n            heads=vision_heads,\n            mlp_ratio=vision_cfg.mlp_ratio,\n            ls_init_value=vision_cfg.ls_init_value,\n            patch_dropout=vision_cfg.patch_dropout,\n            attentional_pool=vision_cfg.attentional_pool,\n            attn_pooler_queries=vision_cfg.attn_pooler_queries,\n            attn_pooler_heads=vision_cfg.attn_pooler_heads,\n            pos_embed_type=vision_cfg.pos_embed_type,\n            no_ln_pre=vision_cfg.no_ln_pre,\n            final_ln_after_pool=vision_cfg.final_ln_after_pool,\n            pool_type=vision_cfg.pool_type,\n            output_tokens=vision_cfg.output_tokens,\n            output_dim=embed_dim,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n\n    return visual\n\n\ndef _build_text_tower(\n        embed_dim: int,\n        text_cfg: CLIPTextCfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None,\n):\n    if isinstance(text_cfg, dict):\n        text_cfg = CLIPTextCfg(**text_cfg)\n\n    if text_cfg.hf_model_name:\n        text = HFTextEncoder(\n            text_cfg.hf_model_name,\n            output_dim=embed_dim,\n            proj_type=text_cfg.hf_proj_type,\n            pooler_type=text_cfg.hf_pooler_type,\n            pretrained=text_cfg.hf_model_pretrained,\n            output_tokens=text_cfg.output_tokens,\n        )\n    else:\n        act_layer = QuickGELU if quick_gelu else nn.GELU\n        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n        if text_cfg.norm_kwargs:\n            norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)\n        if text_cfg.act_kwargs is not None:\n            act_layer = partial(act_layer, **text_cfg.act_kwargs)\n\n        text = TextTransformer(\n            context_length=text_cfg.context_length,\n            vocab_size=text_cfg.vocab_size,\n            width=text_cfg.width,\n            heads=text_cfg.heads,\n            layers=text_cfg.layers,\n            mlp_ratio=text_cfg.mlp_ratio,\n            ls_init_value=text_cfg.ls_init_value,\n            output_dim=embed_dim,\n            embed_cls=text_cfg.embed_cls,\n            no_causal_mask=text_cfg.no_causal_mask,\n            pad_id=text_cfg.pad_id,\n            pool_type=text_cfg.pool_type,\n            proj_bias=text_cfg.proj_bias,\n            output_tokens=text_cfg.output_tokens,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n    return text\n\n\nclass CLIP(nn.Module):\n    output_dict: torch.jit.Final[bool]\n    arch_type: torch.jit.Final[str] = 'clip'\n\n    def __init__(\n            self,\n            embed_dim: int,\n            vision_cfg: CLIPVisionCfg,\n            text_cfg: CLIPTextCfg,\n            quick_gelu: bool = False,\n            init_logit_scale: float = np.log(1 / 0.07),\n            init_logit_bias: Optional[float] = None,\n            cast_dtype: Optional[torch.dtype] = None,\n            output_dict: bool = False,\n    ):\n        super().__init__()\n        self.output_dict = output_dict\n\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n\n        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n        self.transformer = text.transformer\n        self.context_length = text.context_length\n        self.vocab_size = text.vocab_size\n        self.token_embedding = text.token_embedding\n        self.positional_embedding = text.positional_embedding\n        
self.ln_final = text.ln_final\n        self.text_projection = text.text_projection\n        self.text_pool_type = text.pool_type\n        self.register_buffer('attn_mask', text.attn_mask, persistent=False)\n\n        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)\n        if init_logit_bias is not None:\n            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)\n        else:\n            self.logit_bias = None\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.transformer.grad_checkpointing = enable\n\n    def encode_image(self, image, normalize: bool = False):\n        features = self.visual(image)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def encode_text(self, text, normalize: bool = False):\n        cast_dtype = self.transformer.get_cast_dtype()\n\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\n\n        x = x + self.positional_embedding.to(cast_dtype)\n        x = self.transformer(x, attn_mask=self.attn_mask)\n        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]\n        x, _ = text_global_pool(x, text, self.text_pool_type)\n        if self.text_projection is not None:\n            if isinstance(self.text_projection, nn.Linear):\n                x = self.text_projection(x)\n            else:\n                x = x @ self.text_projection\n\n        return F.normalize(x, dim=-1) if normalize else x\n\n    def get_logits(self, image, text):\n        image_features = self.encode_image(image, normalize=True)\n        text_features = self.encode_text(text, normalize=True)\n        image_logits = self.logit_scale.exp() * image_features @ text_features.T\n        if self.logit_bias is not None:\n            image_logits += self.logit_bias\n        text_logits = image_logits.T\n        return image_logits, text_logits\n\n    def forward(\n            self,\n            image: Optional[torch.Tensor] = None,\n            text: Optional[torch.Tensor] = None,\n    ):\n        image_features = self.encode_image(image, normalize=True) if image is not None else None\n        text_features = self.encode_text(text, normalize=True) if text is not None else None\n\n        if self.output_dict:\n            out_dict = {\n                \"image_features\": image_features,\n                \"text_features\": text_features,\n                \"logit_scale\": self.logit_scale.exp()\n            }\n            if self.logit_bias is not None:\n                out_dict['logit_bias'] = self.logit_bias\n            return out_dict\n\n        if self.logit_bias is not None:\n            return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n        return image_features, text_features, self.logit_scale.exp()\n\n\nclass CustomTextCLIP(nn.Module):\n    output_dict: torch.jit.Final[bool]\n    arch_type: torch.jit.Final[str] = 'clip'\n\n    def __init__(\n            self,\n            embed_dim: int,\n            vision_cfg: CLIPVisionCfg,\n            text_cfg: CLIPTextCfg,\n            quick_gelu: bool = False,\n            init_logit_scale: float = np.log(1 / 0.07),\n            init_logit_bias: Optional[float] = None,\n           
 cast_dtype: Optional[torch.dtype] = None,\n            output_dict: bool = False,\n    ):\n        super().__init__()\n        self.output_dict = output_dict\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n        self.context_length = self.text.context_length\n        self.vocab_size = self.text.vocab_size\n        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)\n        if init_logit_bias is not None:\n            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)\n        else:\n            self.logit_bias = None\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n\n    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n        self.text.lock(unlocked_layers, freeze_layer_norm)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.text.set_grad_checkpointing(enable)\n\n    def encode_image(self, image, normalize: bool = False):\n        features = self.visual(image)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def encode_text(self, text, normalize: bool = False):\n        features = self.text(text)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def get_logits(self, image, text):\n        image_features = self.encode_image(image, normalize=True)\n        text_features = self.encode_text(text, normalize=True)\n        image_logits = self.logit_scale.exp() * image_features @ text_features.T\n        if self.logit_bias is not None:\n            image_logits += self.logit_bias\n        text_logits = image_logits.T\n        return image_logits, text_logits\n\n    def forward(\n            self,\n            image: Optional[torch.Tensor] = None,\n            text: Optional[torch.Tensor] = None,\n    ):\n        image_features = self.encode_image(image, normalize=True) if image is not None else None\n        text_features = self.encode_text(text, normalize=True) if text is not None else None\n\n        if self.output_dict:\n            out_dict = {\n                \"image_features\": image_features,\n                \"text_features\": text_features,\n                \"logit_scale\": self.logit_scale.exp()\n            }\n            if self.logit_bias is not None:\n                out_dict['logit_bias'] = self.logit_bias\n            return out_dict\n\n        if self.logit_bias is not None:\n            return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n        return image_features, text_features, self.logit_scale.exp()\n\n\ndef convert_weights_to_lp(model: nn.Module, dtype=torch.float16):\n    \"\"\"Convert applicable model parameters to low-precision (bf16 or fp16)\"\"\"\n\n    def _convert_weights(l):\n        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n            l.weight.data = l.weight.data.to(dtype)\n            if l.bias is not None:\n                l.bias.data = l.bias.data.to(dtype)\n\n        if isinstance(l, (nn.MultiheadAttention, Attention)):\n            for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n                tensor = 
getattr(l, attr)\n                if tensor is not None:\n                    tensor.data = tensor.data.to(dtype)\n\n        if isinstance(l, (CLIP, TextTransformer)):\n            # convert text nn.Parameter projections\n            attr = getattr(l, \"text_projection\", None)\n            if attr is not None:\n                attr.data = attr.data.to(dtype)\n\n        if isinstance(l, VisionTransformer):\n            # convert vision nn.Parameter projections\n            attr = getattr(l, \"proj\", None)\n            if attr is not None:\n                attr.data = attr.data.to(dtype)\n\n    model.apply(_convert_weights)\n\n\nconvert_weights_to_fp16 = convert_weights_to_lp  # backwards compat\n\n\n# used to maintain checkpoint compatibility\ndef convert_to_custom_text_state_dict(state_dict: dict):\n    if 'text_projection' in state_dict:\n        # old format state_dict, move text tower -> .text\n        new_state_dict = {}\n        for k, v in state_dict.items():\n            if any(k.startswith(p) for p in (\n                'text_projection',\n                'positional_embedding',\n                'token_embedding',\n                'transformer',\n                'ln_final',\n            )):\n                k = 'text.' + k\n            new_state_dict[k] = v\n        return new_state_dict\n    return state_dict\n\n\ndef build_model_from_openai_state_dict(\n        state_dict: dict,\n        quick_gelu=True,\n        cast_dtype=torch.float16,\n):\n    vit = \"visual.proj\" in state_dict\n\n    if vit:\n        vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n        vision_layers = len(\n            [k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n        vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n        grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n        image_size = vision_patch_size * grid_size\n    else:\n        counts: list = [\n            len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n        vision_layers = tuple(counts)\n        vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n        output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n        vision_patch_size = None\n        assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n        image_size = output_width * 32\n\n    embed_dim = state_dict[\"text_projection\"].shape[1]\n    context_length = state_dict[\"positional_embedding\"].shape[0]\n    vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n    transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n    transformer_heads = transformer_width // 64\n    transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n\n    vision_cfg = CLIPVisionCfg(\n        layers=vision_layers,\n        width=vision_width,\n        patch_size=vision_patch_size,\n        image_size=image_size,\n    )\n    text_cfg = CLIPTextCfg(\n        context_length=context_length,\n        vocab_size=vocab_size,\n        width=transformer_width,\n        heads=transformer_heads,\n        layers=transformer_layers,\n    )\n    model = CLIP(\n        embed_dim,\n        vision_cfg=vision_cfg,\n        text_cfg=text_cfg,\n        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU\n        cast_dtype=cast_dtype,\n    
)\n\n    for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n        state_dict.pop(key, None)\n    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16\n    model.load_state_dict(state_dict)\n    return model.eval()\n\n\ndef trace_model(model, batch_size=256, device=torch.device('cpu')):\n    model.eval()\n    image_size = model.visual.image_size\n    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\n    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)\n    model = torch.jit.trace_module(\n        model,\n        inputs=dict(\n            forward=(example_images, example_text),\n            encode_text=(example_text,),\n            encode_image=(example_images,)\n        ))\n    model.visual.image_size = image_size\n    return model\n\n\ndef resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):\n    # Rescale the grid of position embeddings when loading from state_dict\n    old_pos_embed = state_dict.get('visual.positional_embedding', None)\n    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\n        return\n    grid_size = to_2tuple(model.visual.grid_size)\n    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)\n    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\n    if new_seq_len == old_pos_embed.shape[0]:\n        return\n\n    if extra_tokens:\n        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\n    else:\n        pos_emb_tok, pos_emb_img = None, old_pos_embed\n    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\n\n    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\n    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\n    pos_emb_img = F.interpolate(\n        pos_emb_img,\n        size=grid_size,\n        mode=interpolation,\n        antialias=antialias,\n        align_corners=False,\n    )\n    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\n    if pos_emb_tok is not None:\n        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\n    else:\n        new_pos_embed = pos_emb_img\n    state_dict['visual.positional_embedding'] = new_pos_embed\n\n\ndef resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):\n    old_pos_embed = state_dict.get('positional_embedding', None)\n    if old_pos_embed is None:\n        return\n    # FIXME add support for text cls_token\n    model_pos_embed = getattr(model, 'positional_embedding', None)\n    if model_pos_embed is None:\n        model_pos_embed = getattr(model.text, 'positional_embedding', None)\n\n    old_num_pos = old_pos_embed.shape[0]\n    old_width = old_pos_embed.shape[1]\n    num_pos = model_pos_embed.shape[0]\n    width = model_pos_embed.shape[1]\n    assert old_width == width, 'text pos_embed width changed!'\n    if old_num_pos == num_pos:\n        return\n\n    logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)\n    old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)\n    old_pos_embed = F.interpolate(\n        old_pos_embed,\n        size=num_pos,\n        mode=interpolation,\n        antialias=antialias,\n        align_corners=False,\n    )\n    old_pos_embed = 
old_pos_embed.permute(0, 2, 1)[0]\n    new_pos_embed = old_pos_embed\n\n    state_dict['positional_embedding'] = new_pos_embed\n\n\ndef get_model_preprocess_cfg(model):\n    module = getattr(model, 'visual', model)\n    preprocess_cfg = getattr(module, 'preprocess_cfg', {})\n    if not preprocess_cfg:\n        # use separate legacy attributes if preprocess_cfg dict not found\n        size = getattr(module, 'image_size')\n        if size is not None:\n            preprocess_cfg['size'] = size\n        mean = getattr(module, 'image_mean', None)\n        if mean is not None:\n            preprocess_cfg['mean'] = mean\n        std = getattr(module, 'image_std', None)\n        if std is not None:\n            preprocess_cfg['std'] = std\n    return preprocess_cfg\n\n\ndef set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):\n    module = getattr(model, 'visual', model)\n    module.image_mean = preprocess_cfg['mean']  # legacy attribute, keeping for bwd compat\n    module.image_std = preprocess_cfg['std']  # legacy attribute, keeping for bwd compat\n    module.preprocess_cfg = copy.deepcopy(preprocess_cfg)  # new attr, package all pp cfg as dict\n\n\ndef get_model_tokenize_cfg(model):\n    module = getattr(model, 'text', model)\n    cfg = {}\n    context_length = getattr(module, 'context_length', None)\n    if context_length is not None:\n        cfg['context_length'] = context_length\n    vocab_size = getattr(module, 'vocab_size', None)\n    if vocab_size is not None:\n        cfg['vocab_size'] = vocab_size\n    return cfg"
  },
  {
    "path": "inf_clip/models/coca_arch.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nfrom dataclasses import dataclass\n\nfrom .transformer import (\n    LayerNormFp32,\n    LayerNorm,\n    QuickGELU,\n    MultimodalTransformer,\n)\nfrom .clip_arch import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower\n\ntry:\n    from transformers import (\n        BeamSearchScorer,\n        LogitsProcessorList,\n        TopPLogitsWarper,\n        TopKLogitsWarper,\n        RepetitionPenaltyLogitsProcessor,\n        MinLengthLogitsProcessor,\n        MaxLengthCriteria,\n        StopStringCriteria,\n        EosTokenCriteria,\n        StoppingCriteriaList\n    )\n\n    GENERATION_TYPES = {\n        \"top_k\": TopKLogitsWarper,\n        \"top_p\": TopPLogitsWarper,\n        \"beam_search\": \"beam_search\"\n    }\n    _has_transformers = True\nexcept ImportError as e:\n    GENERATION_TYPES = {\n        \"top_k\": None,\n        \"top_p\": None,\n        \"beam_search\": \"beam_search\"\n    }\n    _has_transformers = False\n\n\n@dataclass\nclass MultimodalCfg(CLIPTextCfg):\n    mlp_ratio: int = 4\n    dim_head: int = 64\n    heads: int = 8\n    n_queries: int = 256\n    attn_pooler_heads: int = 8\n\n\ndef _build_text_decoder_tower(\n        embed_dim,\n        multimodal_cfg,\n        quick_gelu: bool = False,\n        cast_dtype: Optional[torch.dtype] = None,\n):\n    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n    act_layer = QuickGELU if quick_gelu else nn.GELU\n    norm_layer = (\n        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n    )\n\n    decoder = MultimodalTransformer(\n        context_length=multimodal_cfg.context_length,\n        width=multimodal_cfg.width,\n        heads=multimodal_cfg.heads,\n        layers=multimodal_cfg.layers,\n        ls_init_value=multimodal_cfg.ls_init_value,\n        output_dim=embed_dim,\n        act_layer=act_layer,\n        norm_layer=norm_layer,\n    )\n\n    return decoder\n\n\ndef _token_to_tensor(token_id, device: str = \"cpu\") -> torch.Tensor:\n    if not isinstance(token_id, torch.Tensor):\n        if isinstance(token_id, int):\n            token_id = [token_id]\n        token_id = torch.tensor(token_id, device=device)\n    return token_id\n\n\nclass CoCa(nn.Module):\n    arch_type: torch.jit.Final[str] = 'coca'\n    \n    def __init__(\n            self,\n            embed_dim,\n            multimodal_cfg: MultimodalCfg,\n            text_cfg: CLIPTextCfg,\n            vision_cfg: CLIPVisionCfg,\n            quick_gelu: bool = False,\n            init_logit_scale: float = np.log(1 / 0.07),\n            init_logit_bias: Optional[float] = None,\n            cast_dtype: Optional[torch.dtype] = None,\n            pad_id: int = 0,\n    ):\n        super().__init__()\n        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg\n        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg\n\n        self.text = _build_text_tower(\n            embed_dim=embed_dim,\n            text_cfg=text_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        vocab_size = (\n            text_cfg.vocab_size  # for hf models\n            if hasattr(text_cfg, \"hf_model_name\") and 
text_cfg.hf_model_name is not None\n            else text_cfg.vocab_size\n        )\n\n        self.visual = _build_vision_tower(\n            embed_dim=embed_dim,\n            vision_cfg=vision_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        self.text_decoder = _build_text_decoder_tower(\n            vocab_size,\n            multimodal_cfg=multimodal_cfg,\n            quick_gelu=quick_gelu,\n            cast_dtype=cast_dtype,\n        )\n\n        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)\n        if init_logit_bias is not None:\n            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)\n        else:\n            self.logit_bias = None\n        self.pad_id = pad_id\n\n        self.context_length = multimodal_cfg.context_length\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable: bool = True):\n        self.visual.set_grad_checkpointing(enable)\n        self.text.set_grad_checkpointing(enable)\n        self.text_decoder.set_grad_checkpointing(enable)\n\n    def _encode_image(self, images, normalize: bool = True):\n        image_latent, tokens_embs = self.visual(images)\n        image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent\n        return image_latent, tokens_embs\n\n    def _encode_text(self, text, normalize: bool = True):\n        text_latent, token_emb = self.text(text)\n        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent\n        return text_latent, token_emb\n\n    def encode_image(self, images, normalize: bool = True):\n        image_latent, _ = self._encode_image(images, normalize=normalize)\n        return image_latent\n\n    def encode_text(self, text, normalize: bool = True):\n        text_latent, _ = self._encode_text(text, normalize=normalize)\n        return text_latent\n\n    def forward(\n            self,\n            image,\n            text: Optional[torch.Tensor] = None,\n            image_latent: Optional[torch.Tensor] = None,\n            image_embs: Optional[torch.Tensor] = None,\n            output_labels: bool = True,\n    ):\n        if image_latent is None or image_embs is None:\n            image_latent, image_embs = self._encode_image(image)\n\n        if text is None:\n            return {\"image_features\": image_latent, \"image_embs\": image_embs}\n\n        text_latent, token_embs = self._encode_text(text)\n\n        # FIXME this isn't an ideal solution, would like to improve -RW\n        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None\n        if output_labels:\n            # align text_embs and thus logits with labels for teacher-forcing caption loss\n            token_embs = token_embs[:, :-1]\n\n        logits = self.text_decoder(image_embs, token_embs)\n        out_dict = {\n            \"image_features\": image_latent,\n            \"text_features\": text_latent,\n            \"logits\": logits,\n            \"logit_scale\": self.logit_scale.exp()\n        }\n        if labels is not None:\n            out_dict[\"labels\"] = labels\n        if self.logit_bias is not None:\n            out_dict[\"logit_bias\"] = self.logit_bias\n        return out_dict\n\n    def generate(\n        self,\n        image,\n        text=None,\n        seq_len=30,\n        max_seq_len=77,\n        temperature=1.,\n        generation_type=\"beam_search\",\n        top_p=0.1,  # keep tokens in the 1 - top_p quantile\n        top_k=1,  # keeps the top_k most probable tokens\n 
       pad_token_id=None,\n        eos_token_id=None,\n        sot_token_id=None,\n        num_beams=6,\n        num_beam_groups=3,\n        min_seq_len=5,\n        stopping_criteria=None,\n        repetition_penalty=1.0,\n        fixed_output_length=False # if True output.shape == (batch_size, seq_len)\n    ):\n        # taking many ideas and components from HuggingFace GenerationMixin\n        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation\n        assert _has_transformers, \"Please install transformers for generate functionality. `pip install transformers`.\"\n        assert seq_len > min_seq_len, \"seq_len must be larger than min_seq_len\"\n        device = image.device\n\n        with torch.no_grad():\n            sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)\n            eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)\n            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id\n            logit_processor = LogitsProcessorList(\n                [\n                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),\n                    RepetitionPenaltyLogitsProcessor(repetition_penalty),\n                ]\n            )\n\n            if stopping_criteria is None:\n                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]\n            stopping_criteria = StoppingCriteriaList(stopping_criteria)\n\n            if generation_type == \"beam_search\":\n                output = self._generate_beamsearch(\n                    image_inputs=image,\n                    pad_token_id=pad_token_id,\n                    eos_token_id=eos_token_id,\n                    sot_token_id=sot_token_id,\n                    num_beams=num_beams,\n                    num_beam_groups=num_beam_groups,\n                    min_seq_len=min_seq_len,\n                    stopping_criteria=stopping_criteria,\n                    logit_processor=logit_processor,\n                )\n                if fixed_output_length and output.shape[1] < seq_len:\n                    pad_len = seq_len - output.shape[1]\n                    return torch.cat((\n                            output,\n                            torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id\n                        ),\n                        dim=1\n                    )\n                return output\n\n            elif generation_type == \"top_p\":\n                logit_warper = GENERATION_TYPES[generation_type](top_p)\n            elif generation_type == \"top_k\":\n                logit_warper = GENERATION_TYPES[generation_type](top_k)\n            else:\n                raise ValueError(\n                    f\"generation_type has to be one of \"\n                    f\"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}.\"\n                )\n\n            image_latent, image_embs = self._encode_image(image)\n\n            if text is None:\n                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id\n\n            was_training = self.training\n            num_dims = len(text.shape)\n\n            if num_dims == 1:\n                text = text[None, :]\n\n            self.eval()\n            out = text\n\n            while True:\n                x = out[:, -max_seq_len:]\n                cur_len = x.shape[1]\n                logits = self(\n                    image,\n                 
   x,\n                    image_latent=image_latent,\n                    image_embs=image_embs,\n                    output_labels=False,\n                )[\"logits\"][:, -1]\n                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)\n                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id\n\n                if mask.all():\n                    if not fixed_output_length:\n                        break\n                else:\n                    logits = logits[~mask, :]\n                    filtered_logits = logit_processor(x[~mask, :], logits)\n                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)\n                    probs = F.softmax(filtered_logits / temperature, dim=-1)\n\n                    if (cur_len + 1 == seq_len):\n                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id\n                    else:\n                        sample[~mask, :] = torch.multinomial(probs, 1)\n\n                out = torch.cat((out, sample), dim=-1)\n\n                cur_len += 1\n\n                if all(stopping_criteria(out, None)):\n                    break\n\n            if num_dims == 1:\n                out = out.squeeze(0)\n\n            self.train(was_training)\n            return out\n\n    def _generate_beamsearch(\n            self,\n            image_inputs,\n            pad_token_id=None,\n            eos_token_id=None,\n            sot_token_id=None,\n            num_beams=6,\n            num_beam_groups=3,\n            min_seq_len=5,\n            stopping_criteria=None,\n            logit_processor=None,\n            logit_warper=None,\n    ):\n        device = image_inputs.device\n        batch_size = image_inputs.shape[0]\n        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)\n        image_latent, image_embs = self._encode_image(image_inputs)\n\n        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)\n        input_ids = input_ids * sot_token_id\n        beam_scorer = BeamSearchScorer(\n            batch_size=batch_size,\n            num_beams=num_beams,\n            device=device,\n            num_beam_groups=num_beam_groups,\n        )\n        # instantiate logits processors\n        logits_processor = (\n            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])\n            if logit_processor is None\n            else logit_processor\n        )\n\n        num_beams = beam_scorer.num_beams\n        num_beam_groups = beam_scorer.num_beam_groups\n        num_sub_beams = num_beams // num_beam_groups\n        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups\n        batch_beam_size, cur_len = input_ids.shape\n        beam_indices = None\n\n        if num_beams * batch_size != batch_beam_size:\n            raise ValueError(\n                f\"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}.\"\n            )\n\n        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)\n        # initialise score of first beam of each group with 0 and the rest with 1e-9. 
This ensures that the beams in\n        # the same group don't produce same tokens everytime.\n        beam_scores[:, ::num_sub_beams] = 0\n        beam_scores = beam_scores.view((batch_size * num_beams,))\n\n        while True:\n\n            # predicted tokens in cur_len step\n            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)\n\n            # indices which will form the beams in the next time step\n            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)\n\n            # do one decoder step on all beams of all sentences in batch\n            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)\n            outputs = self(\n                model_inputs['images'],\n                model_inputs['text'],\n                image_latent=image_latent,\n                image_embs=image_embs,\n                output_labels=False,\n            )\n\n            for beam_group_idx in range(num_beam_groups):\n                group_start_idx = beam_group_idx * num_sub_beams\n                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)\n                group_size = group_end_idx - group_start_idx\n\n                # indices of beams of current group among all sentences in batch\n                batch_group_indices = []\n\n                for batch_idx in range(batch_size):\n                    batch_group_indices.extend(\n                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]\n                    )\n                group_input_ids = input_ids[batch_group_indices]\n\n                # select outputs of beams of currentg group only\n                next_token_logits = outputs['logits'][batch_group_indices, -1, :]\n                vocab_size = next_token_logits.shape[-1]\n\n                next_token_scores_processed = logits_processor(\n                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx\n                )\n                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)\n                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)\n\n                # reshape for beam search\n                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)\n\n                next_token_scores, next_tokens = torch.topk(\n                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True\n                )\n\n                next_indices = torch.div(next_tokens, vocab_size, rounding_mode=\"floor\")\n                next_tokens = next_tokens % vocab_size\n\n                # stateless\n                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n                beam_outputs = beam_scorer.process(\n                    group_input_ids,\n                    next_token_scores,\n                    next_tokens,\n                    next_indices,\n                    pad_token_id=pad_token_id,\n                    eos_token_id=eos_token_id,\n                    beam_indices=process_beam_indices,\n                    group_index=beam_group_idx,\n                )\n                beam_scores[batch_group_indices] = beam_outputs[\"next_beam_scores\"]\n                beam_next_tokens = beam_outputs[\"next_beam_tokens\"]\n                beam_idx = beam_outputs[\"next_beam_indices\"]\n\n              
  input_ids[batch_group_indices] = group_input_ids[beam_idx]\n                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)\n                current_tokens[batch_group_indices] = group_input_ids[:, -1]\n\n                # (beam_idx // group_size) -> batch_idx\n                # (beam_idx % group_size) -> offset of idx inside the group\n                reordering_indices[batch_group_indices] = (\n                    num_beams * torch.div(beam_idx, group_size, rounding_mode=\"floor\") + group_start_idx + (beam_idx % group_size)\n                )\n\n            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)\n\n            # increase cur_len\n            cur_len = cur_len + 1\n            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):\n                break\n\n        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n        sequence_outputs = beam_scorer.finalize(\n            input_ids,\n            beam_scores,\n            next_tokens,\n            next_indices,\n            pad_token_id=pad_token_id,\n            eos_token_id=eos_token_id,\n            max_length=stopping_criteria.max_length,\n            beam_indices=final_beam_indices,\n        )\n        return sequence_outputs['sequences']\n\n\ndef prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):\n    if past:\n        input_ids = input_ids[:, -1].unsqueeze(-1)\n\n    attention_mask = kwargs.get(\"attention_mask\", None)\n    position_ids = kwargs.get(\"position_ids\", None)\n\n    if attention_mask is not None and position_ids is None:\n        # create position_ids on the fly for batch generation\n        position_ids = attention_mask.long().cumsum(-1) - 1\n        position_ids.masked_fill_(attention_mask == 0, 1)\n    else:\n        position_ids = None\n    return {\n        \"text\": input_ids,\n        \"images\": image_inputs,\n        \"past_key_values\": past,\n        \"position_ids\": position_ids,\n        \"attention_mask\": attention_mask,\n    }\n"
  },
  {
    "path": "inf_clip/models/hf_configs.py",
    "content": "# HF architecture dict:\narch_dict = {\n    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n    \"roberta\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"hidden_size\",\n            \"heads\": \"num_attention_heads\",\n            \"layers\": \"num_hidden_layers\",\n            \"layer_attr\": \"layer\",\n            \"token_embeddings_attr\": \"embeddings\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig\n    \"xlm-roberta\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"hidden_size\",\n            \"heads\": \"num_attention_heads\",\n            \"layers\": \"num_hidden_layers\",\n            \"layer_attr\": \"layer\",\n            \"token_embeddings_attr\": \"embeddings\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5\n    \"mt5\": {\n        \"config_names\": {\n            # unlimited seqlen\n            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273\n            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374\n            \"context_length\": \"\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"d_model\",\n            \"heads\": \"num_heads\",\n            \"layers\": \"num_layers\",\n            \"layer_attr\": \"block\",\n            \"token_embeddings_attr\": \"embed_tokens\"\n        },\n        \"pooler\": \"mean_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/bert\n    \"bert\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"hidden_size\",\n            \"heads\": \"num_attention_heads\",\n            \"layers\": \"num_hidden_layers\",\n        },\n        \"pooler\": \"cls_pooler\",\n    },\n    # https://huggingface.co/docs/transformers/model_doc/m2m_100\n    \"m2m_100\": {\n        \"config_names\": {\n            \"context_length\": \"max_position_embeddings\",\n            \"vocab_size\": \"vocab_size\",\n            \"width\": \"d_model\",\n            \"heads\": \"encoder_attention_heads\",\n            \"layers\": \"encoder_layers\",\n        },\n        \"pooler\": \"cls_pooler\",\n    },\n}\n"
  },
  {
    "path": "inf_clip/models/hf_model.py",
    "content": "\"\"\" huggingface model adapter\n\nWraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.\n\"\"\"\nimport re\nfrom contextlib import nullcontext\n\nimport torch\nimport torch.nn as nn\nfrom torch import TensorType\n\ntry:\n    import transformers\n    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig\n    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \\\n        BaseModelOutputWithPoolingAndCrossAttentions\n    from transformers.modeling_utils import no_init_weights\nexcept ImportError as e:\n    transformers = None\n\n\n    class BaseModelOutput:\n        pass\n\n\n    class PretrainedConfig:\n        pass\n\nfrom .hf_configs import arch_dict\n\n\n# utils\ndef _camel2snake(s):\n    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()\n\n\n# TODO: ?last - for gpt-like models\n_POOLERS = {}\n\n\ndef register_pooler(cls):\n    \"\"\"Decorator registering pooler class\"\"\"\n    _POOLERS[_camel2snake(cls.__name__)] = cls\n    return cls\n\n\n@register_pooler\nclass MeanPooler(nn.Module):\n    \"\"\"Mean pooling\"\"\"\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)\n        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)\n\n\n@register_pooler\nclass MaxPooler(nn.Module):\n    \"\"\"Max pooling\"\"\"\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)\n        return masked_output.max(1).values\n\n\n@register_pooler\nclass ClsPooler(nn.Module):\n    \"\"\"CLS token pooling\"\"\"\n\n    def __init__(self, use_pooler_output=True):\n        super().__init__()\n        self.cls_token_position = 0\n        self.use_pooler_output = use_pooler_output\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        if (self.use_pooler_output and\n            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and\n            (x.pooler_output is not None)\n        ):\n            return x.pooler_output\n\n        return x.last_hidden_state[:, self.cls_token_position, :]\n\n\n@register_pooler\nclass ClsLastHiddenStatePooler(nn.Module):\n    \"\"\"CLS token pooling\n    NOTE: this is equivalent to ClsPooler above with use_pooler_output=False\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.cls_token_position = 0\n\n    def forward(self, x: BaseModelOutput, attention_mask: TensorType):\n        return x.last_hidden_state[:, self.cls_token_position, :]\n\n\nclass HFTextEncoder(nn.Module):\n    \"\"\"HuggingFace model adapter\"\"\"\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            model_name_or_path: str,\n            output_dim: int,\n            pooler_type: str = None,\n            proj_type: str = None,\n            pretrained: bool = True,\n            output_tokens: bool = False,\n    ):\n        super().__init__()\n        self.output_tokens = output_tokens\n        self.output_dim = output_dim\n\n        # TODO: find better way to get this information\n        uses_transformer_pooler = (pooler_type == \"cls_pooler\")\n\n        if transformers is None:\n            raise RuntimeError(\"Please `pip install transformers` to use pre-trained HuggingFace models\")\n\n        
self.config = AutoConfig.from_pretrained(model_name_or_path)\n        # FIXME: Gradient Accumulation can't fully resume the dropout state, so we close dropout here.\n        # self.config.attention_probs_dropout_prob = 0.0  # Disable dropout\n        # self.config.hidden_dropout_prob = 0.0  # Disable dropout\n        # Enable sdpa attention\n        self.config._attn_implementation = 'sdpa'\n        # initialization of the model is really slow (https://github.com/huggingface/transformers/issues/9205#issuecomment-748741195)\n        # FIXME: To speed up the initialization of the model, we only load pretrained weights and \n        # disable the torch initialization of the weights.\n        if pretrained:\n            context = no_init_weights\n        else:\n            context = nullcontext\n\n        with context():\n            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??\n            if hasattr(self.config, \"is_encoder_decoder\") and self.config.is_encoder_decoder:\n                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True)\n                self.transformer = self.transformer.encoder\n            else:\n                self.transformer = AutoModel.from_pretrained(model_name_or_path, config=self.config, use_safetensors=True, add_pooling_layer=uses_transformer_pooler)\n\n        if pooler_type is None:  # get default arch pooler\n            pooler_type = (arch_dict[self.config.model_type][\"pooler\"])\n\n        # FIXME downstream users of OpenCLIP models use these attr, need to verify valid across all models\n        self.vocab_size = getattr(self.config, 'vocab_size', 0)\n        self.context_length = getattr(self.config, 'max_position_embeddings', 0)\n\n        self.pooler = _POOLERS[pooler_type]()\n\n        d_model = getattr(self.config, arch_dict[self.config.model_type][\"config_names\"][\"width\"])\n\n        if (d_model == output_dim) and (proj_type is None):  # do we always need a proj?\n            self.proj = nn.Identity()\n        elif proj_type == 'linear':\n            self.proj = nn.Linear(d_model, output_dim, bias=False)\n        elif proj_type == 'mlp':\n            hidden_size = (d_model + output_dim) // 2\n            self.proj = nn.Sequential(\n                nn.Linear(d_model, hidden_size, bias=False),\n                nn.GELU(),\n                nn.Linear(hidden_size, output_dim, bias=False),\n            )\n\n    def forward(self, x: TensorType):\n        attn_mask = (x != self.config.pad_token_id).long()\n        out = self.transformer(input_ids=x, attention_mask=attn_mask)\n        pooled_out = self.pooler(out, attn_mask)\n        projected = self.proj(pooled_out)\n\n        seq_len = out.last_hidden_state.shape[1]\n        tokens = (\n            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] \n            if type(self.pooler) == ClsPooler \n            else out.last_hidden_state\n        )\n        \n        if self.output_tokens:\n            return projected, tokens\n        return projected\n\n    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n        if not unlocked_layers:  # full freezing\n            for n, p in self.transformer.named_parameters():\n                p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\n            return\n\n        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer\n        
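# NOTE: arch_dict supplies per-architecture attribute names; the token embeddings plus all but the last `unlocked_layers` transformer blocks are frozen below (LayerNorm weights stay trainable only when freeze_layer_norm is False).\n        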
layer_list = getattr(encoder, arch_dict[self.config.model_type][\"config_names\"][\"layer_attr\"])\n        print(f\"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model\")\n        embeddings = getattr(\n            self.transformer, arch_dict[self.config.model_type][\"config_names\"][\"token_embeddings_attr\"])\n        modules = [embeddings, *layer_list][:-unlocked_layers]\n        # freeze layers\n        for module in modules:\n            for n, p in module.named_parameters():\n                p.requires_grad = (not freeze_layer_norm) if \"LayerNorm\" in n.split(\".\") else False\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.gradient_checkpointing_enable()\n\n    def init_parameters(self):\n        pass\n"
  },
  {
    "path": "inf_clip/models/lit_arch.py",
    "content": "from dataclasses import dataclass\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as torch_checkpoint\n\nfrom .clip_arch import _build_vision_tower, _build_text_tower\n\n\n@dataclass\nclass LiTVisionCfg:\n    layers: Union[Tuple[int, int, int, int], int] = 12\n    width: int = 768\n    head_width: int = 64\n    mlp_ratio: float = 4.0\n    patch_size: int = 16\n    image_size: Union[Tuple[int, int], int] = 224\n\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\n    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)\n    attn_pooler_queries: int = 256  # n_queries for attentional pooler\n    attn_pooler_heads: int = 8  # n heads for attentional_pooling\n    no_ln_pre: bool = False  # disable pre transformer LayerNorm\n    pos_embed_type: str = 'learnable'\n    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling\n    pool_type: str = 'tok'\n    output_tokens: bool = False\n    act_kwargs: Optional[dict] = None\n    norm_kwargs: Optional[dict] = None\n\n    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size\n    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model\n    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')\n    timm_proj_bias: bool = False  # enable bias final projection\n    timm_drop: float = 0.  
# head dropout\n    timm_drop_path: Optional[float] = None  # backbone stochastic depth\n\n\n@dataclass\nclass LiTTextCfg:\n    context_length: int = 77\n    vocab_size: int = 49408\n    hf_tokenizer_name: Optional[str] = None\n    tokenizer_kwargs: Optional[dict] = None\n\n    width: int = 512\n    heads: int = 8\n    layers: int = 12\n    mlp_ratio: float = 4.0\n    ls_init_value: Optional[float] = None  # layer scale initial value\n    embed_cls: bool = False\n    pad_id: int = 0\n    no_causal_mask: bool = False  # disable causal masking\n    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling\n    pool_type: str = 'argmax'\n    proj_bias: bool = False\n    output_tokens: bool = False\n    act_kwargs: dict = None\n    norm_kwargs: dict = None\n\n    # HuggingFace specific text tower config\n    hf_model_name: Optional[str] = None\n    hf_model_pretrained: bool = True\n    hf_proj_type: str = 'mlp'\n    hf_pooler_type: str = 'mean_pooler'  # attentional pooling for HF models\n\n\nclass LiT(nn.Module):\n    output_dict: torch.jit.Final[bool]\n    arch_type: torch.jit.Final[str] = 'lit'\n\n    def __init__(\n            self,\n            embed_dim: int,\n            vision_cfg: LiTVisionCfg,\n            text_cfg: LiTTextCfg,\n            quick_gelu: bool = False,\n            init_logit_scale: float = np.log(1 / 0.07),\n            init_logit_bias: Optional[float] = None,\n            cast_dtype: Optional[torch.dtype] = None,\n            output_dict: bool = False,\n    ):\n        super().__init__()\n        self.output_dict = output_dict\n        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n        self.context_length = self.text.context_length\n        self.vocab_size = self.text.vocab_size\n        self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)\n        if init_logit_bias is not None:\n            self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias)\n        else:\n            self.logit_bias = None\n\n        self.embed_dim = embed_dim\n        self.lock_image_tower()\n\n    def get_embed_dim(self):\n        return self.embed_dim\n\n    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n\n    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n        self.text.lock(unlocked_layers, freeze_layer_norm)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.visual.set_grad_checkpointing(enable)\n        self.text.set_grad_checkpointing(enable)\n\n    def encode_image(self, image, normalize: bool = False):\n        features = self.visual(image)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def encode_trunk_image(self, image, normalize: bool = False):\n        trunk_features = self.visual.forward_trunk(image)\n        features = self.visual.forward_head(trunk_features)\n        return trunk_features, F.normalize(features, dim=-1) if normalize else features\n\n    def project_image(self, trunk_features, normalize: bool = False):\n        features = self.visual.head(trunk_features)\n        return trunk_features, F.normalize(features, dim=-1) if normalize else features\n\n    def encode_text(self, text, normalize: bool = False):\n   
     features = self.text(text)\n        return F.normalize(features, dim=-1) if normalize else features\n\n    def get_logits(self, image, text):\n        image_features = self.encode_image(image, normalize=True)\n        text_features = self.encode_text(text, normalize=True)\n        image_logits = self.logit_scale.exp() * image_features @ text_features.T\n        if self.logit_bias is not None:\n            image_logits += self.logit_bias\n        text_logits = image_logits.T\n        return image_logits, text_logits\n\n    def forward(\n            self,\n            image: Optional[torch.Tensor] = None,\n            text: Optional[torch.Tensor] = None,\n            project_only: Optional[bool] = False,\n    ):\n        if project_only:\n            image_trunk_features, image_features = self.project_image(image, normalize=True) if image is not None else (None, None)\n        else:\n            image_trunk_features, image_features = self.encode_trunk_image(image, normalize=True) if image is not None else (None, None)\n        text_features = self.encode_text(text, normalize=True) if text is not None else None\n\n        if self.output_dict:\n            out_dict = {\n                \"image_trunk_features\": image_trunk_features,\n                \"image_features\": image_features,\n                \"text_features\": text_features,\n                \"logit_scale\": self.logit_scale.exp()\n            }\n            if self.logit_bias is not None:\n                out_dict['logit_bias'] = self.logit_bias\n            return out_dict\n\n        if self.logit_bias is not None:\n            return image_trunk_features, image_features, text_features, self.logit_scale.exp(), self.logit_bias\n        return image_trunk_features, image_features, text_features, self.logit_scale.exp()\n"
  },
  {
    "path": "inf_clip/models/loss.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import functional as F\n\ntry:\n    import torch.distributed.nn\n    from torch import distributed as dist\n\n    has_distributed = True\nexcept ImportError:\n    has_distributed = False\n\ntry:\n    import horovod.torch as hvd\nexcept ImportError:\n    hvd = None\n\nfrom inf_cl import cal_flash_loss, cal_ring_loss, cal_inf_loss\n\n\ndef gather_features(\n        image_features,\n        text_features,\n        local_loss=False,\n        gather_with_grad=False,\n        rank=0,\n        world_size=1,\n        use_horovod=False\n):\n    if world_size == 1:\n        return image_features, text_features\n    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'\n    if use_horovod:\n        assert hvd is not None, 'Please install horovod'\n        if gather_with_grad:\n            all_image_features = hvd.allgather(image_features)\n            all_text_features = hvd.allgather(text_features)\n        else:\n            with torch.no_grad():\n                all_image_features = hvd.allgather(image_features)\n                all_text_features = hvd.allgather(text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))\n                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n                all_image_features = torch.cat(gathered_image_features, dim=0)\n                all_text_features = torch.cat(gathered_text_features, dim=0)\n    else:\n        # We gather tensors from all gpus\n        if gather_with_grad:\n            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)\n            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)\n        else:\n            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]\n            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]\n            dist.all_gather(gathered_image_features, image_features)\n            dist.all_gather(gathered_text_features, text_features)\n            if not local_loss:\n                # ensure grads for local rank when all_* features don't have a gradient\n                gathered_image_features[rank] = image_features\n                gathered_text_features[rank] = text_features\n            all_image_features = torch.cat(gathered_image_features, dim=0)\n            all_text_features = torch.cat(gathered_text_features, dim=0)\n\n    return all_image_features, all_text_features\n\n\ndef all_reduce(tensor):\n    if not dist.is_available():\n        return tensor\n    else:\n        world_size = dist.get_world_size()\n        dist.all_reduce(tensor)\n        return tensor\n\n\nclass ClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            local_loss=False,\n            gather_with_grad=False,\n            cache_labels=False,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.local_loss = local_loss\n        self.gather_with_grad = gather_with_grad\n        self.cache_labels = cache_labels\n        self.rank = rank\n        self.world_size = 
world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def get_ground_truth(self, device, num_logits) -> torch.Tensor:\n        # calculated ground-truth and cache if enabled\n        if self.prev_num_logits != num_logits or device not in self.labels:\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\n            if self.world_size > 1 and self.local_loss:\n                labels = labels + num_logits * self.rank\n            if self.cache_labels:\n                self.labels[device] = labels\n                self.prev_num_logits = num_logits\n        else:\n            labels = self.labels[device]\n        return labels\n\n    def get_logits(self, image_features, text_features, logit_scale):\n        if self.world_size > 1:\n            all_image_features, all_text_features = gather_features(\n                image_features, text_features,\n                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)\n\n            if self.local_loss:\n                logits_per_image = logit_scale * image_features @ all_text_features.T\n                logits_per_text = logit_scale * text_features @ all_image_features.T\n            else:\n                logits_per_image = logit_scale * all_image_features @ all_text_features.T\n                logits_per_text = logits_per_image.T\n        else:\n            logits_per_image = logit_scale * image_features @ text_features.T\n            logits_per_text = logit_scale * text_features @ image_features.T\n        \n        return logits_per_image, logits_per_text\n\n    def forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)\n\n        labels = self.get_ground_truth(device, logits_per_image.shape[0])\n\n        total_loss = (\n            F.cross_entropy(logits_per_image, labels, reduction='none') +\n            F.cross_entropy(logits_per_text, labels, reduction='none')\n        ) / 2\n\n        scale_factor = (total_loss.shape[0] / image_features.shape[0])\n        total_loss = torch.mean(total_loss * scale_factor)\n\n        show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)\n\n        return {\"contrastive_loss\": total_loss, \"show_loss\": show_loss}\n\n\nclass DiscoClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def get_ground_truth(self, device, num_logits) -> torch.Tensor:\n        # calculated ground-truth and cache if enabled\n        if self.prev_num_logits != num_logits or device not in self.labels:\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\n            labels = labels + num_logits * self.rank\n            self.labels[device] = labels\n            self.prev_num_logits = num_logits\n        else:\n            labels = self.labels[device]\n        return labels\n\n    def forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n\n        all_image_features, all_text_features = gather_features(\n            
image_features, text_features,\n            True, True, self.rank, self.world_size, self.use_horovod)\n\n        logits_per_image = logit_scale * image_features @ all_text_features.T\n        logits_per_text = logit_scale * text_features @ all_image_features.T\n\n        labels = self.get_ground_truth(device, logits_per_image.shape[0])\n\n        total_loss = (\n            F.cross_entropy(logits_per_image, labels, reduction='none') +\n            F.cross_entropy(logits_per_text, labels, reduction='none')\n        ) / 2\n\n        total_loss = torch.mean(total_loss)\n\n        show_loss = all_reduce(total_loss.detach().clone()) / self.world_size\n\n        return {\"contrastive_loss\": total_loss, \"show_loss\": show_loss}\n\n\nclass FlashClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def get_ground_truth(self, device, num_logits) -> torch.Tensor:\n        # calculated ground-truth and cache if enabled\n        if self.prev_num_logits != num_logits or device not in self.labels:\n            labels = torch.arange(num_logits, device=device, dtype=torch.long)\n            self.labels[device] = labels\n            self.prev_num_logits = num_logits\n        else:\n            labels = self.labels[device]\n        return labels\n\n    def forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n\n        all_image_features, all_text_features = gather_features(\n            image_features, text_features,\n            False, False, self.rank, self.world_size, self.use_horovod)\n        labels = self.get_ground_truth(device, all_image_features.shape[0])\n\n        i2t_loss = cal_flash_loss(logit_scale * all_image_features, all_text_features,  labels)\n        t2i_loss = cal_flash_loss(logit_scale * all_text_features,  all_image_features, labels)\n        total_loss = (i2t_loss + t2i_loss) / 2\n\n        scale_factor = (total_loss.shape[0] / image_features.shape[0])\n        total_loss = torch.mean(total_loss * scale_factor)\n\n        show_loss = all_reduce(total_loss.detach().clone()) / (self.world_size * scale_factor)\n\n        return {\"contrastive_loss\": total_loss, \"show_loss\": show_loss}\n\n\nclass RingClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.labels = {}\n\n    def forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n\n        q = image_features\n        k = text_features\n        l = logit_scale\n\n        if device in self.labels:\n            labels = self.labels[device]\n        else:\n            labels = torch.arange(q.shape[0], device=device, dtype=torch.long)\n            self.labels[device] = labels\n\n        i2t_loss = cal_ring_loss(q, k, labels, l)\n        t2i_loss = cal_ring_loss(k, q, labels, l)\n\n        total_loss = (i2t_loss + t2i_loss) / 2\n        total_loss = total_loss.mean()\n\n        show_loss = all_reduce(total_loss.detach().clone()) / self.world_size\n\n        return 
{\"contrastive_loss\": total_loss, \"show_loss\": show_loss}\n\n\nclass InfClipLoss(nn.Module):\n\n    def __init__(\n            self,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.rank = rank\n        self.world_size = world_size\n        self.use_horovod = use_horovod\n\n        # cache state\n        self.labels = {}\n\n    def forward(self, image_features, text_features, logit_scale):\n        device = image_features.device\n\n        q = image_features\n        k = text_features\n        l = logit_scale\n\n        if device in self.labels:\n            labels = self.labels[device]\n        else:\n            labels = torch.arange(q.shape[0], device=device, dtype=torch.long)\n            self.labels[device] = labels\n\n        i2t_loss = cal_inf_loss(q, k, labels, l)\n        t2i_loss = cal_inf_loss(k, q, labels, l)\n\n        total_loss = (i2t_loss + t2i_loss) / 2\n        total_loss = total_loss.mean()\n\n        show_loss = all_reduce(total_loss.detach().clone()) / self.world_size\n\n        return {\"contrastive_loss\": total_loss, \"show_loss\": show_loss}\n\n        # NOTE: debug code for checkint ring loss gradient\n        # rank = dist.get_rank()\n\n        # q = image_features.detach().clone().requires_grad_()\n        # k = text_features.detach().clone().requires_grad_()\n        # l = logit_scale.detach().clone().requires_grad_()\n\n        # all_q, all_k = gather_features(\n        #         q, k,\n        #         self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)\n        # if rank == 0:\n        #     import numpy as np\n        #     np.save('all_q.npy', all_q.detach().cpu().numpy())\n        #     np.save('all_k.npy', all_k.detach().cpu().numpy())\n\n        # labels = torch.arange(all_q.shape[0], device=device, dtype=torch.long)\n\n        # qk = l * all_q @ all_k.T\n        # lse = torch.logsumexp(qk, dim=1)\n        # # numerator = torch.einsum(\"md,md->m\",l * all_q, all_k[labels])\n        # numerator = qk[torch.arange(qk.shape[0]), labels]\n        # i2t_loss = -numerator + lse\n        # lse = torch.logsumexp(qk.T, dim=1)\n        # # numerator = torch.einsum(\"md,md->m\", l * all_k, all_q[labels])\n        # numerator = qk.T[labels, torch.arange(qk.shape[0])]\n        # t2i_loss = -numerator + lse\n        # # i2t_loss = F.cross_entropy(qk, labels, reduction='none')\n        # # t2i_loss = F.cross_entropy(qk.T, labels, reduction='none')\n\n        # total_loss = (i2t_loss + t2i_loss) / 2\n        # total_loss.sum().backward()\n\n        # q1 = image_features.detach().clone().requires_grad_()\n        # k1 = text_features.detach().clone().requires_grad_()\n        # l1 = logit_scale.detach().clone().requires_grad_()\n\n        # labels = torch.arange(q1.shape[0], device=device, dtype=torch.long)\n\n        # i2t_loss1 = _cal_ring_loss(l1 * q1, k1, labels)\n        # t2i_loss1 = _cal_ring_loss(l1 * k1, q1, labels)\n\n        # total_loss1 = (i2t_loss1 + t2i_loss1) / 2\n\n        # q1.retain_grad(); k1.retain_grad(); l1.retain_grad()\n        # total_loss1.sum().backward()\n        # if rank == 1:\n        #     import numpy as np\n        #     np.save('q_r1_grad.npy', q.grad.detach().cpu().numpy())\n        #     np.save('q1_r1_grad.npy', q1.grad.detach().cpu().numpy())\n        #     print(q.grad, q1.grad)\n        #     print(torch.max(torch.abs(q.grad - q1.grad)), torch.max(torch.abs(k.grad - k1.grad)), torch.max(torch.abs(l.grad - 
l1.grad)))\n        # exit(0)\n\n\nclass CoCaLoss(ClipLoss):\n    def __init__(\n            self,\n            caption_loss_weight,\n            clip_loss_weight,\n            pad_id=0,  # pad_token for open_clip custom tokenizer\n            local_loss=False,\n            gather_with_grad=False,\n            cache_labels=False,\n            rank=0,\n            world_size=1,\n            use_horovod=False,\n    ):\n        super().__init__(\n            local_loss=local_loss,\n            gather_with_grad=gather_with_grad,\n            cache_labels=cache_labels,\n            rank=rank,\n            world_size=world_size,\n            use_horovod=use_horovod\n        )\n\n        self.clip_loss_weight = clip_loss_weight\n        self.caption_loss_weight = caption_loss_weight\n        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)\n\n    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):\n        \n        clip_loss = torch.tensor(0)\n        \n        if self.clip_loss_weight:\n            clip_loss = super().forward(image_features, text_features, logit_scale)\n            clip_loss = self.clip_loss_weight * clip_loss\n\n        caption_loss = self.caption_loss(\n            logits.permute(0, 2, 1),\n            labels,\n        )\n        caption_loss = caption_loss * self.caption_loss_weight\n\n        if output_dict:\n            return {\"contrastive_loss\": clip_loss, \"caption_loss\": caption_loss}\n\n        return clip_loss, caption_loss\n\n\nclass DistillClipLoss(ClipLoss):\n\n    def dist_loss(self, teacher_logits, student_logits):\n        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)\n\n    def forward(\n            self,\n            image_features,\n            text_features,\n            logit_scale,\n            dist_image_features,\n            dist_text_features,\n            dist_logit_scale,\n            output_dict=False,\n    ):\n        logits_per_image, logits_per_text = \\\n            self.get_logits(image_features, text_features, logit_scale)\n\n        dist_logits_per_image, dist_logits_per_text = \\\n            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)\n\n        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])\n\n        contrastive_loss = (\n            F.cross_entropy(logits_per_image, labels) +\n            F.cross_entropy(logits_per_text, labels)\n        ) / 2\n\n        distill_loss = (\n            self.dist_loss(dist_logits_per_image, logits_per_image) +\n            self.dist_loss(dist_logits_per_text, logits_per_text)\n        ) / 2\n\n        if output_dict:\n            return {\"contrastive_loss\": contrastive_loss, \"distill_loss\": distill_loss}\n\n        return contrastive_loss, distill_loss\n\n\ndef neighbour_exchange(from_rank, to_rank, tensor, group=None):\n    tensor_recv = torch.zeros_like(tensor)\n    send_op = torch.distributed.P2POp(\n        torch.distributed.isend,\n        tensor,\n        to_rank,\n        group=group,\n    )\n    recv_op = torch.distributed.P2POp(\n        torch.distributed.irecv,\n        tensor_recv,\n        from_rank,\n        group=group,\n    )\n    reqs = torch.distributed.batch_isend_irecv([send_op, recv_op])\n    for req in reqs:\n        req.wait()\n    return tensor_recv\n\n\ndef neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):\n    tensor_from_left = 
torch.zeros_like(tensor_to_right)\n    tensor_from_right = torch.zeros_like(tensor_to_left)\n    send_op_left = torch.distributed.P2POp(\n        torch.distributed.isend,\n        tensor_to_left,\n        left_rank,\n        group=group,\n    )\n    send_op_right = torch.distributed.P2POp(\n        torch.distributed.isend,\n        tensor_to_right,\n        right_rank,\n        group=group,\n    )\n    recv_op_left = torch.distributed.P2POp(\n        torch.distributed.irecv,\n        tensor_from_left,\n        left_rank,\n        group=group,\n    )\n    recv_op_right = torch.distributed.P2POp(\n        torch.distributed.irecv,\n        tensor_from_right,\n        right_rank,\n        group=group,\n    )\n    reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left])\n    for req in reqs:\n        req.wait()\n    return tensor_from_right, tensor_from_left\n\n\nclass NeighbourExchange(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, from_rank, to_rank, group, tensor):\n        ctx.group = group\n        ctx.from_rank = from_rank\n        ctx.to_rank = to_rank\n        return neighbour_exchange(from_rank, to_rank, tensor, group=group)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),)\n\n\ndef neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):\n    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)\n\n\nclass NeighbourExchangeBidir(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):\n        ctx.group = group\n        ctx.left_rank = left_rank\n        ctx.right_rank = right_rank\n        return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group)\n\n    @staticmethod\n    def backward(ctx, *grad_outputs):\n        return (None, None, None) + \\\n            NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs)\n\n\ndef neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):\n    return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right)\n\n\nclass SigLipLoss(nn.Module):\n    \"\"\" Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343\n\n    @article{zhai2023sigmoid,\n      title={Sigmoid loss for language image pre-training},\n      author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas},\n      journal={arXiv preprint arXiv:2303.15343},\n      year={2023}\n    }\n    \"\"\"\n    def __init__(\n            self,\n            cache_labels=False,\n            rank=0,\n            world_size=1,\n            bidir=True,\n            use_horovod=False,\n    ):\n        super().__init__()\n        self.cache_labels = cache_labels\n        self.rank = rank\n        self.world_size = world_size\n        assert not use_horovod  # FIXME need to look at hvd ops for ring transfers\n        self.use_horovod = use_horovod\n        self.bidir = bidir\n\n        # cache state FIXME cache not currently used, worthwhile?\n        self.prev_num_logits = 0\n        self.labels = {}\n\n    def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor:\n        labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype)\n        if not negative_only:\n    
        labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels\n        return labels\n\n    def get_logits(self, image_features, text_features, logit_scale, logit_bias=None):\n        logits = logit_scale * image_features @ text_features.T\n        if logit_bias is not None:\n            logits += logit_bias\n        return logits\n\n    def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False):\n        logits = self.get_logits(image_features, text_features, logit_scale, logit_bias)\n        labels = self.get_ground_truth(\n            image_features.device,\n            image_features.dtype,\n            image_features.shape[0],\n            negative_only=negative_only,\n        )\n        loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0]\n        return loss\n\n    def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False):\n        loss = self._loss(image_features, text_features, logit_scale, logit_bias)\n\n        if self.world_size > 1:\n            # exchange text features w/ neighbour world_size - 1 times\n            right_rank = (self.rank + 1) % self.world_size\n            left_rank = (self.rank - 1 + self.world_size) % self.world_size\n            if self.bidir:\n                text_features_to_right = text_features_to_left = text_features\n                num_bidir, remainder = divmod(self.world_size - 1, 2)\n                for i in range(num_bidir):\n                    text_features_recv = neighbour_exchange_bidir_with_grad(\n                        left_rank,\n                        right_rank,\n                        text_features_to_left,\n                        text_features_to_right,\n                    )\n\n                    for f in text_features_recv:\n                        loss += self._loss(\n                            image_features,\n                            f,\n                            logit_scale,\n                            logit_bias,\n                            negative_only=True,\n                        )\n                    text_features_to_left, text_features_to_right = text_features_recv\n\n                if remainder:\n                    text_features_recv = neighbour_exchange_with_grad(\n                        left_rank, right_rank, text_features_to_right)\n\n                    loss += self._loss(\n                        image_features,\n                        text_features_recv,\n                        logit_scale,\n                        logit_bias,\n                        negative_only=True,\n                    )\n            else:\n                text_features_to_right = text_features\n                for i in range(self.world_size - 1):\n                    text_features_from_left = neighbour_exchange_with_grad(\n                        left_rank, right_rank, text_features_to_right)\n\n                    loss += self._loss(\n                        image_features,\n                        text_features_from_left,\n                        logit_scale,\n                        logit_bias,\n                        negative_only=True,\n                    )\n                    text_features_to_right = text_features_from_left\n\n        return {\"contrastive_loss\": loss} if output_dict else loss\n"
  },
  {
    "path": "inf_clip/models/modified_resnet.py",
    "content": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom ..utils import freeze_batch_norm_2d\n\n\nclass Bottleneck(nn.Module):\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1):\n        super().__init__()\n\n        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1\n        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.act1 = nn.ReLU(inplace=True)\n\n        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.act2 = nn.ReLU(inplace=True)\n\n        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n\n        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n        self.act3 = nn.ReLU(inplace=True)\n\n        self.downsample = None\n        self.stride = stride\n\n        if stride > 1 or inplanes != planes * Bottleneck.expansion:\n            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n            self.downsample = nn.Sequential(OrderedDict([\n                (\"-1\", nn.AvgPool2d(stride)),\n                (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n                (\"1\", nn.BatchNorm2d(planes * self.expansion))\n            ]))\n\n    def forward(self, x: torch.Tensor):\n        identity = x\n\n        out = self.act1(self.bn1(self.conv1(x)))\n        out = self.act2(self.bn2(self.conv2(out)))\n        out = self.avgpool(out)\n        out = self.bn3(self.conv3(out))\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.act3(out)\n        return out\n\n\nclass AttentionPool2d(nn.Module):\n    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n        self.k_proj = nn.Linear(embed_dim, embed_dim)\n        self.q_proj = nn.Linear(embed_dim, embed_dim)\n        self.v_proj = nn.Linear(embed_dim, embed_dim)\n        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n        self.num_heads = num_heads\n\n    def forward(self, x):\n        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC\n        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC\n        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC\n        x, _ = F.multi_head_attention_forward(\n            query=x, key=x, value=x,\n            embed_dim_to_check=x.shape[-1],\n            num_heads=self.num_heads,\n            q_proj_weight=self.q_proj.weight,\n            k_proj_weight=self.k_proj.weight,\n            v_proj_weight=self.v_proj.weight,\n            in_proj_weight=None,\n            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n            bias_k=None,\n            bias_v=None,\n            add_zero_attn=False,\n            dropout_p=0.,\n            out_proj_weight=self.c_proj.weight,\n            out_proj_bias=self.c_proj.bias,\n            use_separate_proj_weight=True,\n            training=self.training,\n            need_weights=False\n        )\n\n     
   return x[0]\n\n\nclass ModifiedResNet(nn.Module):\n    \"\"\"\n    A ResNet class that is similar to torchvision's but contains the following changes:\n    - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n    - The final pooling layer is a QKV attention instead of an average pool\n    \"\"\"\n\n    def __init__(self, layers, output_dim, heads, image_size=224, width=64):\n        super().__init__()\n        self.output_dim = output_dim\n        self.image_size = image_size\n\n        # the 3-layer stem\n        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(width // 2)\n        self.act1 = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(width // 2)\n        self.act2 = nn.ReLU(inplace=True)\n        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(width)\n        self.act3 = nn.ReLU(inplace=True)\n        self.avgpool = nn.AvgPool2d(2)\n\n        # residual layers\n        self._inplanes = width  # this is a *mutable* variable used during construction\n        self.layer1 = self._make_layer(width, layers[0])\n        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n\n        embed_dim = width * 32  # the ResNet feature dimension\n        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\n\n        self.init_parameters()\n\n    def _make_layer(self, planes, blocks, stride=1):\n        layers = [Bottleneck(self._inplanes, planes, stride)]\n\n        self._inplanes = planes * Bottleneck.expansion\n        for _ in range(1, blocks):\n            layers.append(Bottleneck(self._inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def init_parameters(self):\n        if self.attnpool is not None:\n            std = self.attnpool.c_proj.in_features ** -0.5\n            nn.init.normal_(self.attnpool.q_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.k_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.v_proj.weight, std=std)\n            nn.init.normal_(self.attnpool.c_proj.weight, std=std)\n\n        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\n            for name, param in resnet_block.named_parameters():\n                if name.endswith(\"bn3.weight\"):\n                    nn.init.zeros_(param)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        assert unlocked_groups == 0, 'partial locking not currently supported for this model'\n        for param in self.parameters():\n            param.requires_grad = False\n        if freeze_bn_stats:\n            freeze_batch_norm_2d(self)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        # FIXME support for non-transformer\n        pass\n\n    def stem(self, x):\n        x = self.act1(self.bn1(self.conv1(x)))\n        x = self.act2(self.bn2(self.conv2(x)))\n        x = self.act3(self.bn3(self.conv3(x)))\n        x = self.avgpool(x)\n        return x\n\n    def forward(self, x):\n        x = self.stem(x)\n        x = self.layer1(x)\n    
    x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x = self.attnpool(x)\n\n        return x\n"
  },
  {
    "path": "inf_clip/models/pos_embed.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n# --------------------------------------------------------\n# Position embedding utils\n# --------------------------------------------------------\n\nimport numpy as np\n\nimport torch\n\n# --------------------------------------------------------\n# 2D sine-cosine position embedding\n# References:\n# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py\n# MoCo v3: https://github.com/facebookresearch/moco-v3\n# --------------------------------------------------------\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token:\n        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 2 == 0\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=float)\n    omega /= embed_dim / 2.\n    omega = 1. 
/ 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out) # (M, D/2)\n    emb_cos = np.cos(out) # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\n# --------------------------------------------------------\n# Interpolate position embeddings for high-resolution\n# References:\n# DeiT: https://github.com/facebookresearch/deit\n# --------------------------------------------------------\ndef interpolate_pos_embed(model, checkpoint_model):\n    if 'pos_embed' in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model['pos_embed']\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model['pos_embed'] = new_pos_embed\n"
  },
  {
    "path": "inf_clip/models/timm_model.py",
    "content": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.\n\"\"\"\nimport logging\nfrom collections import OrderedDict\n\nimport torch\nimport torch.nn as nn\n\ntry:\n    import timm\n    from timm.models.layers import Mlp, to_2tuple\n    try:\n        # old timm imports < 0.8.1\n        from timm.models.layers.attention_pool2d import RotAttentionPool2d\n        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d\n    except ImportError:\n        # new timm imports >= 0.8.1\n        from timm.layers import RotAttentionPool2d\n        from timm.layers import AttentionPool2d as AbsAttentionPool2d\nexcept ImportError:\n    timm = None\n\nfrom ..utils import freeze_batch_norm_2d\n\n\nclass TimmModel(nn.Module):\n    \"\"\" timm model adapter\n    \"\"\"\n\n    def __init__(\n            self,\n            model_name,\n            embed_dim,\n            image_size=224,\n            pool='avg',\n            proj='linear',\n            proj_bias=False,\n            drop=0.,\n            drop_path=None,\n            patch_drop=None,\n            pretrained=False,\n    ):\n        super().__init__()\n        if timm is None:\n            raise RuntimeError(\"Please `pip install timm` to use timm models.\")\n        self.image_size = to_2tuple(image_size)\n\n        # setup kwargs that may not be common across all models\n        timm_kwargs = {}\n        if drop_path is not None:\n            timm_kwargs['drop_path_rate'] = drop_path\n        if patch_drop is not None:\n            timm_kwargs['patch_drop_rate'] = patch_drop\n\n        custom_pool = pool in ('abs_attn', 'rot_attn')\n        if proj:\n            assert proj in (\"linear\", \"mlp\", \"none\")\n        extra_proj = proj in (\"linear\", \"mlp\")\n        if not extra_proj and not custom_pool:\n            # use network classifier head as projection if no proj specified and no custom pooling used\n            # if projection is explicitly set to \"none\" will be pass through from network trunk\n            proj_dim = 0 if proj == 'none' else embed_dim\n            self.trunk = timm.create_model(\n                model_name,\n                num_classes=proj_dim,\n                global_pool=pool,\n                pretrained=pretrained,\n                **timm_kwargs,\n            )\n            prev_chs = embed_dim\n        else:\n            self.trunk = timm.create_model(\n                model_name,\n                pretrained=pretrained,\n                **timm_kwargs,\n            )\n            feat_size = self.trunk.default_cfg.get('pool_size', None)\n            feature_ndim = 1 if not feat_size else 2\n            if custom_pool:\n                assert feature_ndim == 2\n                # if attn pooling used, remove both classifier and default pool\n                self.trunk.reset_classifier(0, global_pool='')\n            else:\n                # reset global pool if pool config set, otherwise leave as network default\n                reset_kwargs = dict(global_pool=pool) if pool else {}\n                self.trunk.reset_classifier(0, **reset_kwargs)\n            prev_chs = self.trunk.num_features\n\n        head_layers = OrderedDict()\n\n        # Add custom pooling to head\n        if pool == 'abs_attn':\n            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)\n            prev_chs = embed_dim\n        elif pool == 'rot_attn':\n            
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)\n            prev_chs = embed_dim\n\n        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used\n        if proj == 'linear':\n            head_layers['drop'] = nn.Dropout(drop)\n            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)\n        elif proj == 'mlp':\n            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))\n\n        self.head = nn.Sequential(head_layers)\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        \"\"\" lock modules\n        Args:\n            unlocked_groups (int): leave last n layer groups unlocked (default: 0)\n        \"\"\"\n        if not unlocked_groups:\n            # lock full model\n            for param in self.trunk.parameters():\n                param.requires_grad = False\n            if freeze_bn_stats:\n                freeze_batch_norm_2d(self.trunk)\n        else:\n            # NOTE: partial freeze requires latest timm (master) branch and is subject to change\n            try:\n                # FIXME import here until API stable and in an official release\n                from timm.models.helpers import group_parameters, group_modules\n            except ImportError:\n                raise RuntimeError(\n                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')\n            matcher = self.trunk.group_matcher()\n            gparams = group_parameters(self.trunk, matcher)\n            max_layer_id = max(gparams.keys())\n            max_layer_id = max_layer_id - unlocked_groups\n            for group_idx in range(max_layer_id + 1):\n                group = gparams[group_idx]\n                for param in group:\n                    self.trunk.get_parameter(param).requires_grad = False\n            if freeze_bn_stats:\n                gmodules = group_modules(self.trunk, matcher, reverse=True)\n                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}\n                freeze_batch_norm_2d(self.trunk, gmodules)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        try:\n            self.trunk.set_grad_checkpointing(enable)\n        except Exception as e:\n            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')\n\n    def forward_trunk(self, x):\n        return self.trunk(x)\n\n    def forward_head(self, x):\n        return self.head(x)\n\n    def forward(self, x):\n        x = self.forward_trunk(x)\n        x = self.forward_head(x)\n        return x\n"
  },
  {
    "path": "inf_clip/models/tokenizer.py",
    "content": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\nimport gzip\nimport html\nimport os\nimport random\nimport string\nfrom functools import lru_cache, partial\nfrom typing import Callable, List, Optional, Union\nimport warnings\n\nimport ftfy\nimport numpy as np\nimport regex as re\nimport torch\n\n# https://stackoverflow.com/q/62691279\nos.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n_nltk_init = False\n\nDEFAULT_CONTEXT_LENGTH = 77  # default context length for OpenAI CLIP\n\n\n@lru_cache()\ndef default_bpe():\n    return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"bpe_simple_vocab_16e6.txt.gz\")\n\n\n@lru_cache()\ndef bytes_to_unicode():\n    \"\"\"\n    Returns list of utf-8 byte and a corresponding list of unicode strings.\n    The reversible bpe codes work on unicode strings.\n    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.\n    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.\n    This is a significant percentage of your normal, say, 32K bpe vocab.\n    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.\n    And avoids mapping to whitespace/control characters the bpe code barfs on.\n    \"\"\"\n    bs = list(range(ord(\"!\"), ord(\"~\")+1))+list(range(ord(\"¡\"), ord(\"¬\")+1))+list(range(ord(\"®\"), ord(\"ÿ\")+1))\n    cs = bs[:]\n    n = 0\n    for b in range(2**8):\n        if b not in bs:\n            bs.append(b)\n            cs.append(2**8+n)\n            n += 1\n    cs = [chr(n) for n in cs]\n    return dict(zip(bs, cs))\n\n\ndef get_pairs(word):\n    \"\"\"Return set of symbol pairs in a word.\n    Word is represented as tuple of symbols (symbols being variable-length strings).\n    \"\"\"\n    pairs = set()\n    prev_char = word[0]\n    for char in word[1:]:\n        pairs.add((prev_char, char))\n        prev_char = char\n    return pairs\n\n\ndef basic_clean(text):\n    text = ftfy.fix_text(text)\n    text = html.unescape(html.unescape(text))\n    return text.strip()\n\n\ndef whitespace_clean(text):\n    text = \" \".join(text.split())\n    text = text.strip()\n    return text\n\n\ndef _clean_canonicalize(x):\n    # basic, remove whitespace, remove punctuation, lower case\n    return canonicalize_text(basic_clean(x))\n\n\ndef _clean_lower(x):\n    # basic, remove whitespace, lower case\n    return whitespace_clean(basic_clean(x)).lower()\n\n\ndef _clean_whitespace(x):\n    # basic, remove whitespace\n    return whitespace_clean(basic_clean(x))\n\n\ndef get_clean_fn(type: str):\n    if type == 'canonicalize':\n        return _clean_canonicalize\n    elif type == 'lower':\n        return _clean_lower\n    elif type == 'whitespace':\n        return _clean_whitespace\n    else:\n        assert False, f\"Invalid clean function ({type}).\"\n\n\ndef canonicalize_text(\n    text,\n    *,\n    keep_punctuation_exact_string=None,\n    trans_punctuation: dict = str.maketrans(\"\", \"\", string.punctuation),\n):\n    \"\"\"Returns canonicalized `text` (lowercase and punctuation removed).\n\n    From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94\n\n    Args:\n      text: string to be canonicalized.\n      keep_punctuation_exact_string: If provided, then this exact string kept.\n        For example providing '{}' will keep any occurrences of 
'{}' (but will\n        still remove '{' and '}' that appear separately).\n    \"\"\"\n    text = text.replace(\"_\", \" \")\n    if keep_punctuation_exact_string:\n        text = keep_punctuation_exact_string.join(\n            part.translate(trans_punctuation)\n            for part in text.split(keep_punctuation_exact_string)\n        )\n    else:\n        text = text.translate(trans_punctuation)\n    text = text.lower()\n    text = \" \".join(text.split())\n    return text.strip()\n\n\nclass SimpleTokenizer(object):\n    def __init__(\n            self,\n            bpe_path: str = default_bpe(),\n            additional_special_tokens: Optional[List[str]] = None,\n            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,\n            clean: str = 'lower',\n            reduction_mask: str = ''\n    ):\n        self.byte_encoder = bytes_to_unicode()\n        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n        merges = gzip.open(bpe_path).read().decode(\"utf-8\").split('\\n')\n        merges = merges[1:49152-256-2+1]\n        merges = [tuple(merge.split()) for merge in merges]\n        vocab = list(bytes_to_unicode().values())\n        vocab = vocab + [v+'</w>' for v in vocab]\n        for merge in merges:\n            vocab.append(''.join(merge))\n        special_tokens = ['<start_of_text>', '<end_of_text>']\n        if additional_special_tokens:\n            special_tokens += additional_special_tokens\n        vocab.extend(special_tokens)\n        self.encoder = dict(zip(vocab, range(len(vocab))))\n        self.decoder = {v: k for k, v in self.encoder.items()}\n        self.bpe_ranks = dict(zip(merges, range(len(merges))))\n        self.cache = {t:t for t in special_tokens}\n        special = \"|\".join(special_tokens)\n        self.pat = re.compile(\n            special + r\"\"\"|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\",\n            re.IGNORECASE,\n        )\n        self.vocab_size = len(self.encoder)\n        self.all_special_ids = [self.encoder[t] for t in special_tokens]\n        self.sot_token_id = self.all_special_ids[0]\n        self.eot_token_id = self.all_special_ids[1]\n        self.context_length = context_length\n        self.clean_fn = get_clean_fn(clean)\n        self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None\n\n    def bpe(self, token):\n        if token in self.cache:\n            return self.cache[token]\n        word = tuple(token[:-1]) + ( token[-1] + '</w>',)\n        pairs = get_pairs(word)\n\n        if not pairs:\n            return token+'</w>'\n\n        while True:\n            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))\n            if bigram not in self.bpe_ranks:\n                break\n            first, second = bigram\n            new_word = []\n            i = 0\n            while i < len(word):\n                try:\n                    j = word.index(first, i)\n                    new_word.extend(word[i:j])\n                    i = j\n                except Exception:\n                    new_word.extend(word[i:])\n                    break\n\n                if word[i] == first and i < len(word)-1 and word[i+1] == second:\n                    new_word.append(first+second)\n                    i += 2\n                else:\n                    new_word.append(word[i])\n                    i += 1\n            new_word = tuple(new_word)\n            word = new_word\n            if len(word) == 1:\n                break\n  
          else:\n                pairs = get_pairs(word)\n        word = ' '.join(word)\n        self.cache[token] = word\n        return word\n\n    def encode(self, text):\n        bpe_tokens = []\n        text = self.clean_fn(text)\n        for token in re.findall(self.pat, text):\n            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n        return bpe_tokens\n\n    def decode(self, tokens):\n        text = ''.join([self.decoder[token] for token in tokens])\n        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('</w>', ' ')\n        return text\n\n    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor:\n        \"\"\" Returns the tokenized representation of given input string(s)\n\n        Parameters\n        ----------\n        texts : Union[str, List[str]]\n            An input string or a list of input strings to tokenize\n        context_length : int\n            The context length to use; all CLIP models use 77 as the context length\n\n        Returns\n        -------\n        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]\n        \"\"\"\n        if isinstance(texts, str):\n            texts = [texts]\n\n        context_length = context_length or self.context_length\n        assert context_length, 'Please set a valid context length'\n\n        if self.reduction_fn is not None:\n            # use reduction strategy for tokenize if set, otherwise default to truncation below\n            return self.reduction_fn(\n                texts,\n                context_length=context_length,\n                sot_token_id=self.sot_token_id,\n                eot_token_id=self.eot_token_id,\n                encode_fn=self.encode,\n            )\n\n        all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts]\n        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n        for i, tokens in enumerate(all_tokens):\n            if len(tokens) > context_length:\n                tokens = tokens[:context_length]  # Truncate\n                tokens[-1] = self.eot_token_id\n            result[i, :len(tokens)] = torch.tensor(tokens)\n\n        return result\n\n\n_tokenizer = SimpleTokenizer()\n\n\ndef decode(output_ids: torch.Tensor):\n    output_ids = output_ids.cpu().numpy()\n    return _tokenizer.decode(output_ids)\n\n\ndef tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor:\n    return _tokenizer(texts, context_length=context_length)\n\n\ndef random_mask_tokenize(\n        texts: Union[str, List[str]],\n        context_length: int,\n        sot_token_id: int,\n        eot_token_id: int,\n        encode_fn: Callable,\n        shuffle: bool = False,\n):\n    all_tokens = [encode_fn(text) for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        tokens = torch.tensor(tokens)\n        num_tokens = len(tokens)\n        if num_tokens > context_length - 2:  # 2 for sot and eot token\n            num_keep = context_length - 2\n            indices = torch.randperm(len(tokens))\n            indices = indices[:num_keep]\n            if not shuffle:\n                indices = indices.msort()\n            
tokens = tokens[indices]\n            num_tokens = num_keep\n        result[i, 0] = sot_token_id\n        result[i, 1:num_tokens + 1] = tokens\n        result[i, num_tokens + 1] = eot_token_id\n\n    return result\n\n\ndef simple_mask_tokenize(\n        texts: Union[str, List[str]],\n        context_length: int,\n        sot_token_id: int,\n        eot_token_id: int,\n        encode_fn: Callable,\n):\n    all_tokens = [encode_fn(text) for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        num_tokens = len(tokens)\n        if num_tokens > context_length - 2:  # 2 for sot and eot token\n            num_keep = context_length - 2\n            start_index = random.randint(0, num_tokens - num_keep)  # high is incl\n            tokens = tokens[start_index: start_index + num_keep]\n        tokens = [sot_token_id] + tokens + [eot_token_id]\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result\n\n\ndef syntax_mask_tokenize(\n        texts: Union[str, List[str]],\n        context_length: int,\n        sot_token_id: int,\n        eot_token_id: int,\n        encode_fn: Callable,\n) -> torch.LongTensor:\n    \"\"\" Returns the tokenized representation of given input string(s).\n    Apply syntax masking before tokenize.\n    \"\"\"\n    import nltk\n    global _nltk_init\n    if not _nltk_init:\n        # run them for the first time\n        nltk.download('punkt')\n        nltk.download('averaged_perceptron_tagger')\n        _nltk_init = True\n\n    def get_order(x):\n        if x.startswith('NN'):\n            return 1\n        elif x.startswith('JJ'):\n            return 2\n        elif x.startswith('VB'):\n            return 3\n        else:\n            return 4\n\n    # syntax masking\n    new_texts = []\n    for text in texts:\n        list_tokens = nltk.tokenize.word_tokenize(text)\n        pos_tags = nltk.pos_tag(list_tokens)\n        #  sample the words by get_order method\n        order_list = [get_order(tag) for _, tag in pos_tags]\n        sorted_ids = np.argsort(np.array(order_list))\n        sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens\n        sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0)  # sample the tokens\n\n        new_text = ''\n        for token in sampled_tokens:\n            new_text = new_text + str(token) + ' '\n        new_text = new_text.strip()\n        new_texts.append(new_text)\n    texts = new_texts\n\n    all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts]\n    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n\n    for i, tokens in enumerate(all_tokens):\n        # still need first truncate because some words produces two tokens\n        if len(tokens) > context_length:\n            tokens = tokens[:context_length]  # Truncate\n            tokens[-1] = eot_token_id\n        result[i, :len(tokens)] = torch.tensor(tokens)\n\n    return result\n\n\ndef get_reduction_mask_fn(type: str):\n    \"\"\" Choose strategy for dropping (masking) tokens to achieve target context length\"\"\"\n    assert type in ('simple', 'random', 'shuffle', 'syntax')\n    if type == 'simple':\n        return simple_mask_tokenize  # randomly select block [start:end]\n    elif type == 'random':\n        return random_mask_tokenize  # randomly drop tokens (keep order)\n    elif type == 'shuffle':\n        return partial(random_mask_tokenize, shuffle=True)  # 
randomly drop tokens (shuffle order)\n    elif type == 'syntax':\n        return syntax_mask_tokenize  # randomly drop prioritized by syntax\n\n\nclass HFTokenizer:\n    \"\"\"HuggingFace tokenizer wrapper\"\"\"\n\n    def __init__(\n            self,\n            tokenizer_name: str,\n            context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH,\n            clean: str = 'whitespace',\n            strip_sep_token: bool = False,\n            language: Optional[str] = None,\n            **kwargs\n    ):\n        from transformers import AutoTokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs)\n        set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None)\n        if callable(set_lang_fn):\n            self.set_lang_fn = set_lang_fn\n        if language is not None:\n            self.set_language(language)\n        self.context_length = context_length\n        self.clean_fn = get_clean_fn(clean)\n        self.strip_sep_token = strip_sep_token\n\n    def save_pretrained(self, dest):\n        self.tokenizer.save_pretrained(dest)\n\n    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:\n        # same cleaning as for default tokenizer, except lowercasing\n        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance\n        if isinstance(texts, str):\n            texts = [texts]\n\n        context_length = context_length or self.context_length\n        assert context_length, 'Please set a valid context length in class init or call.'\n\n        texts = [self.clean_fn(text) for text in texts]\n        input_ids = self.tokenizer.batch_encode_plus(\n            texts,\n            return_tensors='pt',\n            max_length=context_length,\n            padding='max_length',\n            truncation=True,\n        ).input_ids\n\n        if self.strip_sep_token:\n            input_ids = torch.where(\n                input_ids == self.tokenizer.sep_token_id,\n                torch.zeros_like(input_ids),\n                input_ids,\n            )\n\n        return input_ids\n    \n    def set_language(self, src_lang):\n        if hasattr(self, 'set_lang_fn'):\n            self.set_lang_fn(src_lang)\n        else:\n            warnings.warn('Cannot set language for the tokenizer.')\n\n\nclass SigLipTokenizer:\n    \"\"\"HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs\n    \"\"\"\n    VOCAB_FILES = {\n        # english, vocab_size=32_000\n        \"c4-en\": \"http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model\",\n        # used in multilingual models (mT5, PaLI), vocab_size=250_000\n        \"mc4\": \"http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model\",\n    }\n\n    def __init__(\n            self,\n            tokenizer_name: str,\n            context_length: Optional[int] = 64,\n    ):\n        from transformers import T5TokenizerFast\n\n        if tokenizer_name in self.VOCAB_FILES:\n            # FIXME temporary hack?\n            import tempfile\n\n            import fsspec\n            vocab_file = self.VOCAB_FILES[tokenizer_name]\n            with tempfile.NamedTemporaryFile('wb') as dst:\n                with fsspec.open(vocab_file, 'rb') as src:\n                    dst.write(src.read())\n                self.tokenizer = T5TokenizerFast(dst.name, legacy=False)\n        else:\n            self.tokenizer = T5TokenizerFast(tokenizer_name, 
legacy=False)\n\n        self.tokenizer.pad_token_id = 1\n        self.tokenizer.eos_token_id = 1\n        self.context_length = context_length\n\n    def save_pretrained(self, dest):\n        self.tokenizer.save_pretrained(dest)\n\n    def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor:\n        # unlike the default CLIP tokenizer, texts are canonicalized (lowercased, punctuation stripped) before SentencePiece encoding\n        if isinstance(texts, str):\n            texts = [texts]\n\n        context_length = context_length or self.context_length\n        assert context_length, 'Please set a valid context length in class init or call.'\n\n        texts = [canonicalize_text(basic_clean(text)) for text in texts]\n        output = self.tokenizer(\n            texts,\n            return_tensors='pt',\n            max_length=context_length,\n            padding='max_length',\n            truncation=True,\n        )\n        return output.input_ids\n"
  },
  {
    "path": "inf_clip/models/transform.py",
    "content": "import numbers\nimport random\nimport warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n\nimport torch\nimport torchvision.transforms.functional as F\nfrom torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n    CenterCrop, ColorJitter, Grayscale\n\nfrom ..constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\nfrom ..utils import to_2tuple\n\n\n@dataclass\nclass PreprocessCfg:\n    size: Union[int, Tuple[int, int]] = 224\n    mode: str = 'RGB'\n    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN\n    std: Tuple[float, ...] = OPENAI_DATASET_STD\n    interpolation: str = 'bicubic'\n    resize_mode: str = 'shortest'\n    fill_color: int = 0\n\n    def __post_init__(self):\n        assert self.mode in ('RGB',)\n\n    @property\n    def num_channels(self):\n        return 3\n\n    @property\n    def input_size(self):\n        return (self.num_channels,) + to_2tuple(self.size)\n\n_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())\n\n\ndef merge_preprocess_dict(\n        base: Union[PreprocessCfg, Dict],\n        overlay: Dict,\n):\n    \"\"\" Merge overlay key-value pairs on top of base preprocess cfg or dict.\n    Input dicts are filtered based on PreprocessCfg fields.\n    \"\"\"\n    if isinstance(base, PreprocessCfg):\n        base_clean = asdict(base)\n    else:\n        base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}\n    if overlay:\n        overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}\n        base_clean.update(overlay_clean)\n    return base_clean\n\n\ndef merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):\n    return merge_preprocess_dict(base, kwargs)\n\n\n@dataclass\nclass AugmentationCfg:\n    scale: Tuple[float, float] = (0.9, 1.0)\n    ratio: Optional[Tuple[float, float]] = None\n    color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None\n    re_prob: Optional[float] = None\n    re_count: Optional[int] = None\n    use_timm: bool = False\n\n    # params for simclr_jitter_gray\n    color_jitter_prob: float = None\n    gray_scale_prob: float = None\n\n\ndef _setup_size(size, error_msg):\n    if isinstance(size, numbers.Number):\n        return int(size), int(size)\n\n    if isinstance(size, Sequence) and len(size) == 1:\n        return size[0], size[0]\n\n    if len(size) != 2:\n        raise ValueError(error_msg)\n\n    return size\n\n\nclass ResizeKeepRatio:\n    \"\"\" Resize and Keep Ratio\n\n    Copy & paste from `timm`\n    \"\"\"\n\n    def __init__(\n            self,\n            size,\n            longest=0.,\n            interpolation=InterpolationMode.BICUBIC,\n            random_scale_prob=0.,\n            random_scale_range=(0.85, 1.05),\n            random_aspect_prob=0.,\n            random_aspect_range=(0.9, 1.11)\n    ):\n        if isinstance(size, (list, tuple)):\n            self.size = tuple(size)\n        else:\n            self.size = (size, size)\n        self.interpolation = interpolation\n        self.longest = float(longest)  # [0, 1] where 0 == shortest edge, 1 == longest\n        self.random_scale_prob = random_scale_prob\n        self.random_scale_range = random_scale_range\n        self.random_aspect_prob = random_aspect_prob\n        self.random_aspect_range = random_aspect_range\n\n    @staticmethod\n    def get_params(\n            img,\n            target_size,\n            
longest,\n            random_scale_prob=0.,\n            random_scale_range=(0.85, 1.05),\n            random_aspect_prob=0.,\n            random_aspect_range=(0.9, 1.11)\n    ):\n        \"\"\"Get parameters\n        \"\"\"\n        source_size = img.size[::-1]  # h, w\n        h, w = source_size\n        target_h, target_w = target_size\n        ratio_h = h / target_h\n        ratio_w = w / target_w\n        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)\n        if random_scale_prob > 0 and random.random() < random_scale_prob:\n            ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])\n            ratio_factor = (ratio_factor, ratio_factor)\n        else:\n            ratio_factor = (1., 1.)\n        if random_aspect_prob > 0 and random.random() < random_aspect_prob:\n            aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])\n            ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)\n        size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]\n        return size\n\n    def __call__(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image): Image to be cropped and resized.\n\n        Returns:\n            PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size\n        \"\"\"\n        size = self.get_params(\n            img, self.size, self.longest,\n            self.random_scale_prob, self.random_scale_range,\n            self.random_aspect_prob, self.random_aspect_range\n        )\n        img = F.resize(img, size, self.interpolation)\n        return img\n\n    def __repr__(self):\n        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)\n        format_string += f', interpolation={self.interpolation})'\n        format_string += f', longest={self.longest:.3f})'\n        return format_string\n\n\ndef center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:\n    \"\"\"Center crops and/or pads the given image.\n    If the image is torch Tensor, it is expected\n    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n\n    Args:\n        img (PIL Image or Tensor): Image to be cropped.\n        output_size (sequence or int): (height, width) of the crop box. 
If int or sequence with single int,\n            it is used for both directions.\n        fill (int, Tuple[int]): Padding color\n\n    Returns:\n        PIL Image or Tensor: Cropped image.\n    \"\"\"\n    if isinstance(output_size, numbers.Number):\n        output_size = (int(output_size), int(output_size))\n    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:\n        output_size = (output_size[0], output_size[0])\n\n    _, image_height, image_width = F.get_dimensions(img)\n    crop_height, crop_width = output_size\n\n    if crop_width > image_width or crop_height > image_height:\n        padding_ltrb = [\n            (crop_width - image_width) // 2 if crop_width > image_width else 0,\n            (crop_height - image_height) // 2 if crop_height > image_height else 0,\n            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,\n            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,\n        ]\n        img = F.pad(img, padding_ltrb, fill=fill)\n        _, image_height, image_width = F.get_dimensions(img)\n        if crop_width == image_width and crop_height == image_height:\n            return img\n\n    crop_top = int(round((image_height - crop_height) / 2.0))\n    crop_left = int(round((image_width - crop_width) / 2.0))\n    return F.crop(img, crop_top, crop_left, crop_height, crop_width)\n\n\nclass CenterCropOrPad(torch.nn.Module):\n    \"\"\"Crops the given image at the center.\n    If the image is torch Tensor, it is expected\n    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n\n    Args:\n        size (sequence or int): Desired output size of the crop. If size is an\n            int instead of sequence like (h, w), a square crop (size, size) is\n            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).\n    \"\"\"\n\n    def __init__(self, size, fill=0):\n        super().__init__()\n        self.size = _setup_size(size, error_msg=\"Please provide only two dimensions (h, w) for size.\")\n        self.fill = fill\n\n    def forward(self, img):\n        \"\"\"\n        Args:\n            img (PIL Image or Tensor): Image to be cropped.\n\n        Returns:\n            PIL Image or Tensor: Cropped image.\n        \"\"\"\n        return center_crop_or_pad(img, self.size, fill=self.fill)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}(size={self.size})\"\n\n\ndef _convert_to_rgb(image):\n    return image.convert('RGB')\n\n\nclass color_jitter(object):\n    \"\"\"\n    Apply Color Jitter to the PIL image with a specified probability.\n    \"\"\"\n    def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):\n        assert 0. <= p <= 1.\n        self.p = p\n        self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)\n\n    def __call__(self, img):\n        if random.random() < self.p:\n            return self.transf(img)\n        else:\n            return img\n\n\nclass gray_scale(object):\n    \"\"\"\n    Apply Gray Scale to the PIL image with a specified probability.\n    \"\"\"\n    def __init__(self, p=0.2):\n        assert 0. 
<= p <= 1.\n        self.p = p\n        self.transf = Grayscale(num_output_channels=3)\n\n    def __call__(self, img):\n        if random.random() < self.p:\n            return self.transf(img)\n        else:\n            return img\n\n\ndef image_transform(\n        image_size: Union[int, Tuple[int, int]],\n        is_train: bool,\n        mean: Optional[Tuple[float, ...]] = None,\n        std: Optional[Tuple[float, ...]] = None,\n        resize_mode: Optional[str] = None,\n        interpolation: Optional[str] = None,\n        fill_color: int = 0,\n        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n):\n    mean = mean or OPENAI_DATASET_MEAN\n    if not isinstance(mean, (list, tuple)):\n        mean = (mean,) * 3\n\n    std = std or OPENAI_DATASET_STD\n    if not isinstance(std, (list, tuple)):\n        std = (std,) * 3\n\n    interpolation = interpolation or 'bicubic'\n    assert interpolation in ['bicubic', 'bilinear', 'random']\n    # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set\n    interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC\n\n    resize_mode = resize_mode or 'shortest'\n    assert resize_mode in ('shortest', 'longest', 'squash')\n\n    if isinstance(aug_cfg, dict):\n        aug_cfg = AugmentationCfg(**aug_cfg)\n    else:\n        aug_cfg = aug_cfg or AugmentationCfg()\n\n    normalize = Normalize(mean=mean, std=std)\n\n    if is_train:\n        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}\n        use_timm = aug_cfg_dict.pop('use_timm', False)\n        if use_timm:\n            from timm.data import create_transform  # timm can still be optional\n            if isinstance(image_size, (tuple, list)):\n                assert len(image_size) >= 2\n                input_size = (3,) + image_size[-2:]\n            else:\n                input_size = (3, image_size, image_size)\n\n            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default\n            # drop extra non-timm items\n            aug_cfg_dict.pop('color_jitter_prob', None)\n            aug_cfg_dict.pop('gray_scale_prob', None)\n\n            train_transform = create_transform(\n                input_size=input_size,\n                is_training=True,\n                hflip=0.,\n                mean=mean,\n                std=std,\n                re_mode='pixel',\n                interpolation=interpolation,\n                **aug_cfg_dict,\n            )\n        else:\n            train_transform = [\n                RandomResizedCrop(\n                    image_size,\n                    scale=aug_cfg_dict.pop('scale'),\n                    interpolation=InterpolationMode.BICUBIC,\n                ),\n                _convert_to_rgb,\n            ]\n            if aug_cfg.color_jitter_prob:\n                assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4\n                train_transform.extend([\n                    color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)\n                ])\n            if aug_cfg.gray_scale_prob:\n                train_transform.extend([\n                    gray_scale(aug_cfg.gray_scale_prob)\n                ])\n            train_transform.extend([\n                ToTensor(),\n                normalize,\n            ])\n            train_transform = Compose(train_transform)\n            if aug_cfg_dict:\n                warnings.warn(f'Unused augmentation cfg items, 
specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')\n        return train_transform\n    else:\n        if resize_mode == 'longest':\n            transforms = [\n                ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),\n                CenterCropOrPad(image_size, fill=fill_color)\n            ]\n        elif resize_mode == 'squash':\n            if isinstance(image_size, int):\n                image_size = (image_size, image_size)\n            transforms = [\n                Resize(image_size, interpolation=interpolation_mode),\n            ]\n        else:\n            assert resize_mode == 'shortest'\n            if not isinstance(image_size, (tuple, list)):\n                image_size = (image_size, image_size)\n            if image_size[0] == image_size[1]:\n                # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)\n                transforms = [\n                    Resize(image_size[0], interpolation=interpolation_mode)\n                ]\n            else:\n                # resize shortest edge to matching target dim for non-square target\n                transforms = [ResizeKeepRatio(image_size)]\n            transforms += [CenterCrop(image_size)]\n\n        transforms.extend([\n            _convert_to_rgb,\n            ToTensor(),\n            normalize,\n        ])\n        return Compose(transforms)\n\n\ndef image_transform_v2(\n        cfg: PreprocessCfg,\n        is_train: bool,\n        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n):\n    return image_transform(\n        image_size=cfg.size,\n        is_train=is_train,\n        mean=cfg.mean,\n        std=cfg.std,\n        interpolation=cfg.interpolation,\n        resize_mode=cfg.resize_mode,\n        fill_color=cfg.fill_color,\n        aug_cfg=aug_cfg,\n    )\n"
  },
  {
    "path": "inf_clip/models/transformer.py",
    "content": "from collections import OrderedDict\nimport math\nfrom typing import Callable, List, Optional, Sequence, Tuple, Union\nfrom functools import partial\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils.checkpoint import checkpoint\n\nfrom .pos_embed import get_2d_sincos_pos_embed\nfrom ..utils import to_2tuple\n\n\nclass LayerNormFp32(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)\n        return x.to(orig_type)\n\n\nclass LayerNorm(nn.LayerNorm):\n    \"\"\"Subclass torch's LayerNorm (with cast back to input dtype).\"\"\"\n\n    def forward(self, x: torch.Tensor):\n        orig_type = x.dtype\n        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        return x.to(orig_type)\n\n\nclass QuickGELU(nn.Module):\n    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory\n    def forward(self, x: torch.Tensor):\n        return x * torch.sigmoid(1.702 * x)\n\n\nclass LayerScale(nn.Module):\n    def __init__(self, dim, init_values=1e-5, inplace=False):\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x):\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n\n\nclass PatchDropout(nn.Module):\n    \"\"\"\n    https://arxiv.org/abs/2212.00794\n    \"\"\"\n\n    def __init__(self, prob, exclude_first_token=True):\n        super().__init__()\n        assert 0 <= prob < 1.\n        self.prob = prob\n        self.exclude_first_token = exclude_first_token  # exclude CLS token\n\n    def forward(self, x):\n        if not self.training or self.prob == 0.:\n            return x\n\n        if self.exclude_first_token:\n            cls_tokens, x = x[:, :1], x[:, 1:]\n        else:\n            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])\n\n        batch = x.size()[0]\n        num_tokens = x.size()[1]\n\n        batch_indices = torch.arange(batch)\n        batch_indices = batch_indices[..., None]\n\n        keep_prob = 1 - self.prob\n        num_patches_keep = max(1, int(num_tokens * keep_prob))\n\n        rand = torch.randn(batch, num_tokens)\n        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices\n\n        x = x[batch_indices, patch_indices_keep]\n\n        if self.exclude_first_token:\n            x = torch.cat((cls_tokens, x), dim=1)\n\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim: int,\n            num_heads: int = 8,\n            qkv_bias: bool = True,\n            scaled_cosine: bool = False,\n            scale_heads: bool = False,\n            logit_scale_max: float = math.log(1. 
/ 0.01),\n            batch_first: bool = True,\n            attn_drop: float = 0.,\n            proj_drop: float = 0.\n    ):\n        super().__init__()\n        self.scaled_cosine = scaled_cosine\n        self.scale_heads = scale_heads\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n        self.scale = self.head_dim ** -0.5\n        self.logit_scale_max = logit_scale_max\n        self.batch_first = batch_first\n        self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')\n\n        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\n        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\n        if qkv_bias:\n            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\n        else:\n            self.in_proj_bias = None\n\n        if self.scaled_cosine:\n            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\n        else:\n            self.logit_scale = None\n        self.attn_drop = nn.Dropout(attn_drop)\n        if self.scale_heads:\n            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\n        else:\n            self.head_scale = None\n        self.out_proj = nn.Linear(dim, dim)\n        self.out_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):\n        if self.batch_first:\n            x = x.transpose(0, 1)\n\n        L, N, C = x.shape\n        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)\n        q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n        k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n        v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n\n        if attn_mask is not None and attn_mask.dtype == torch.bool:\n            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\n            new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n            attn_mask = new_attn_mask\n\n        if self.logit_scale is not None:\n            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\n            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\n            attn = attn.view(N, self.num_heads, L, L) * logit_scale\n            attn = attn.view(-1, L, L)\n            if attn_mask is not None:\n                attn = attn + attn_mask\n            attn = attn.softmax(dim=-1)\n            attn = self.attn_drop(attn)\n            x = torch.bmm(attn, v)\n        else:\n            if self.use_fsdpa:\n                x = F.scaled_dot_product_attention(\n                    q, k, v,\n                    attn_mask=attn_mask,\n                    dropout_p=self.attn_drop.p if self.training else 0.,\n                )\n            else:\n                q = q * self.scale\n                attn = torch.bmm(q, k.transpose(-1, -2))\n                if attn_mask is not None:\n                    attn += attn_mask\n                attn = attn.softmax(dim=-1)\n                attn = self.attn_drop(attn)\n                x = torch.bmm(attn, v)\n\n        if self.head_scale is not None:\n            x = x.view(N, self.num_heads, L, C) * self.head_scale\n            x = x.view(-1, L, C)\n\n        x = x.transpose(0, 1).reshape(L, N, C)\n\n        if self.batch_first:\n            x = x.transpose(0, 1)\n\n        x = self.out_proj(x)\n        x = 
self.out_drop(x)\n        return x\n\n\nclass AttentionalPooler(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            context_dim: int,\n            n_head: int = 8,\n            n_queries: int = 256,\n            norm_layer: Callable = LayerNorm,\n    ):\n        super().__init__()\n        self.query = nn.Parameter(torch.randn(n_queries, d_model))\n        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)\n        self.ln_q = norm_layer(d_model)\n        self.ln_k = norm_layer(context_dim)\n\n    def forward(self, x: torch.Tensor):\n        N = x.shape[0]\n        x = self.ln_k(x)\n        q = self.ln_q(self.query)\n        out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]\n        return out\n\n\nclass ResidualAttentionBlock(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            n_head: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            is_cross_attention: bool = False,\n            batch_first: bool = True,\n    ):\n        super().__init__()\n\n        self.ln_1 = norm_layer(d_model)\n        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n        if is_cross_attention:\n            self.ln_1_kv = norm_layer(d_model)\n\n        self.ln_2 = norm_layer(d_model)\n        mlp_width = int(d_model * mlp_ratio)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\n            (\"gelu\", act_layer()),\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\n        ]))\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n    def attention(\n            self,\n            q_x: torch.Tensor,\n            k_x: Optional[torch.Tensor] = None,\n            v_x: Optional[torch.Tensor] = None,\n            attn_mask: Optional[torch.Tensor] = None,\n    ):\n        k_x = k_x if k_x is not None else q_x\n        v_x = v_x if v_x is not None else q_x\n\n        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None\n        return self.attn(\n            q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask\n        )[0]\n\n    def forward(\n            self,\n            q_x: torch.Tensor,\n            k_x: Optional[torch.Tensor] = None,\n            v_x: Optional[torch.Tensor] = None,\n            attn_mask: Optional[torch.Tensor] = None,\n    ):\n        k_x = self.ln_1_kv(k_x) if hasattr(self, \"ln_1_kv\") and k_x is not None else None\n        v_x = self.ln_1_kv(v_x) if hasattr(self, \"ln_1_kv\") and v_x is not None else None\n        x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))\n        x = x + self.ls_2(self.mlp(self.ln_2(x)))\n        return x\n\n\nclass CustomResidualAttentionBlock(nn.Module):\n    def __init__(\n            self,\n            d_model: int,\n            n_head: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            scale_cosine_attn: bool = False,\n            scale_heads: bool = False,\n            scale_attn: bool = False,\n            scale_fc: bool = 
False,\n            batch_first: bool = True,\n    ):\n        super().__init__()\n\n        self.ln_1 = norm_layer(d_model)\n        self.attn = Attention(\n            d_model,\n            n_head,\n            scaled_cosine=scale_cosine_attn,\n            scale_heads=scale_heads,\n            batch_first=batch_first,\n        )\n        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()\n        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n        self.ln_2 = norm_layer(d_model)\n        mlp_width = int(d_model * mlp_ratio)\n        self.mlp = nn.Sequential(OrderedDict([\n            (\"c_fc\", nn.Linear(d_model, mlp_width)),\n            (\"gelu\", act_layer()),\n            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),\n            (\"c_proj\", nn.Linear(mlp_width, d_model))\n        ]))\n        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n\n    def get_reference_weight(self):\n        return self.mlp.c_fc.weight\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))\n        x = x + self.ls_2(self.mlp(self.ln_2(x)))\n        return x\n\n\ndef _expand_token(token, batch_size: int):\n    return token.view(1, 1, -1).expand(batch_size, -1, -1)\n\n\nclass Transformer(nn.Module):\n    def __init__(\n            self,\n            width: int,\n            layers: int,\n            heads: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            batch_first: bool = True,\n    ):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.batch_first = batch_first\n        self.grad_checkpointing = False\n\n        self.resblocks = nn.ModuleList([\n            ResidualAttentionBlock(\n                width,\n                heads,\n                mlp_ratio,\n                ls_init_value=ls_init_value,\n                act_layer=act_layer,\n                norm_layer=norm_layer,\n                batch_first=batch_first,\n            )\n            for _ in range(layers)\n        ])\n\n    def get_cast_dtype(self) -> torch.dtype:\n        if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):\n            return self.resblocks[0].mlp.c_fc.int8_original_dtype\n        return self.resblocks[0].mlp.c_fc.weight.dtype\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        if not self.batch_first:\n            x = x.transpose(0, 1).contiguous()    # NLD -> LND\n        for r in self.resblocks:\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n                x = checkpoint(r, x, None, None, attn_mask)\n            else:\n                x = r(x, attn_mask=attn_mask)\n        if not self.batch_first:\n            x = x.transpose(0, 1)    # LND -> NLD\n        return x\n\n\nclass CustomTransformer(nn.Module):\n    \"\"\" A custom transformer that can use different block types. 
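Currently only the 'CustomResidualAttentionBlock' block type is implemented; any other entry in block_types raises an AssertionError at construction time.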
\"\"\"\n    def __init__(\n            self,\n            width: int,\n            layers: int,\n            heads: int,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            batch_first: bool = True,\n            block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',\n    ):\n        super().__init__()\n        self.width = width\n        self.layers = layers\n        self.batch_first = batch_first  # run trasnformer stack in batch first (N, L, D)\n        self.grad_checkpointing = False\n\n        if isinstance(block_types, str):\n            block_types = [block_types] * layers\n        assert len(block_types) == layers\n\n        def _create_block(bt: str):\n            if bt == 'CustomResidualAttentionBlock':\n                return CustomResidualAttentionBlock(\n                    width,\n                    heads,\n                    mlp_ratio=mlp_ratio,\n                    ls_init_value=ls_init_value,\n                    act_layer=act_layer,\n                    norm_layer=norm_layer,\n                    batch_first=batch_first,\n                )\n            else:\n                assert False\n\n        self.resblocks = nn.ModuleList([\n            _create_block(bt)\n            for bt in block_types\n        ])\n\n    def get_cast_dtype(self) -> torch.dtype:\n        weight = self.resblocks[0].get_reference_weight()\n        if hasattr(weight, 'int8_original_dtype'):\n            return weight.int8_original_dtype\n        return weight.dtype\n\n    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n        if not self.batch_first:\n            x = x.transpose(0, 1)  # NLD -> LND\n\n        for r in self.resblocks:\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n                x = checkpoint(r, x, None, None, attn_mask)\n            else:\n                x = r(x, attn_mask=attn_mask)\n\n        if not self.batch_first:\n            x = x.transpose(0, 1)  # NLD -> LND\n        return x\n\n\nclass VisionTransformer(nn.Module):\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            image_size: int,\n            patch_size: int,\n            width: int,\n            layers: int,\n            heads: int,\n            mlp_ratio: float,\n            ls_init_value: float = None,\n            attentional_pool: bool = False,\n            attn_pooler_queries: int = 256,\n            attn_pooler_heads: int = 8,\n            output_dim: int = 512,\n            patch_dropout: float = 0.,\n            no_ln_pre: bool = False,\n            pos_embed_type: str = 'learnable',\n            pool_type: str = 'tok',\n            final_ln_after_pool: bool = False,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            output_tokens: bool = False,\n    ):\n        super().__init__()\n        assert pool_type in ('tok', 'avg', 'none')\n        self.output_tokens = output_tokens\n        image_height, image_width = self.image_size = to_2tuple(image_size)\n        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)\n        self.grid_size = (image_height // patch_height, image_width // patch_width)\n        self.final_ln_after_pool = final_ln_after_pool  # currently ignored w/ attn pool 
enabled\n        self.output_dim = output_dim\n\n        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)\n\n        # class embeddings and positional embeddings\n        scale = width ** -0.5\n        self.class_embedding = nn.Parameter(scale * torch.randn(width))\n        if pos_embed_type == 'learnable':\n            self.positional_embedding = nn.Parameter(\n                scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))\n        elif pos_embed_type == 'sin_cos_2d':\n            # fixed sin-cos embedding\n            assert self.grid_size[0] == self.grid_size[1],\\\n                'currently sin cos 2d pos embedding only supports square input'\n            self.positional_embedding = nn.Parameter(\n                torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)\n            pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)\n            self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())\n        else:\n            raise ValueError\n\n        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\n        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()\n\n        self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)\n        self.transformer = Transformer(\n            width,\n            layers,\n            heads,\n            mlp_ratio,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n\n        if attentional_pool:\n            if isinstance(attentional_pool, str):\n                self.attn_pool_type = attentional_pool\n                self.pool_type = 'none'\n                if attentional_pool in ('parallel', 'cascade'):\n                    self.attn_pool = AttentionalPooler(\n                        output_dim,\n                        width,\n                        n_head=attn_pooler_heads,\n                        n_queries=attn_pooler_queries,\n                    )\n                    self.attn_pool_contrastive = AttentionalPooler(\n                        output_dim,\n                        width,\n                        n_head=attn_pooler_heads,\n                        n_queries=1,\n                    )\n                else:\n                    assert False\n            else:\n                self.attn_pool_type = ''\n                self.pool_type = pool_type\n                self.attn_pool = AttentionalPooler(\n                    output_dim,\n                    width,\n                    n_head=attn_pooler_heads,\n                    n_queries=attn_pooler_queries,\n                )\n                self.attn_pool_contrastive = None\n            pool_dim = output_dim\n        else:\n            self.attn_pool = None\n            pool_dim = width\n            self.pool_type = pool_type\n\n        self.ln_post = norm_layer(pool_dim)\n        self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))\n\n        self.init_parameters()\n\n    def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n        for param in self.parameters():\n            param.requires_grad = False\n\n        if unlocked_groups != 0:\n            groups = [\n                [\n                    self.conv1,\n                    self.class_embedding,\n                    self.positional_embedding,\n                    
self.ln_pre,\n                ],\n                *self.transformer.resblocks[:-1],\n                [\n                    self.transformer.resblocks[-1],\n                    self.ln_post,\n                ],\n                self.proj,\n            ]\n\n            def _unlock(x):\n                if isinstance(x, Sequence):\n                    for g in x:\n                        _unlock(g)\n                else:\n                    if isinstance(x, torch.nn.Parameter):\n                        x.requires_grad = True\n                    else:\n                        for p in x.parameters():\n                            p.requires_grad = True\n\n            _unlock(groups[-unlocked_groups:])\n\n    def init_parameters(self):\n        # FIXME OpenAI CLIP did not define an init for the VisualTransformer\n        # TODO experiment if default PyTorch init, below, or alternate init is best.\n\n        # nn.init.normal_(self.class_embedding, std=self.scale)\n        # nn.init.normal_(self.positional_embedding, std=self.scale)\n        #\n        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        # attn_std = self.transformer.width ** -0.5\n        # fc_std = (2 * self.transformer.width) ** -0.5\n        # for block in self.transformer.resblocks:\n        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n        #\n        # if self.text_projection is not None:\n        #     nn.init.normal_(self.text_projection, std=self.scale)\n        pass\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.grad_checkpointing = enable\n\n    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n        if self.pool_type == 'avg':\n            pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]\n        elif self.pool_type == 'tok':\n            pooled, tokens = x[:, 0], x[:, 1:]\n        else:\n            pooled = tokens = x\n\n        return pooled, tokens\n\n    def forward(self, x: torch.Tensor):\n        x = self.conv1(x)  # shape = [*, width, grid, grid]\n        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]\n        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]\n\n        # class embeddings and positional embeddings\n        x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)\n        # shape = [*, grid ** 2 + 1, width]\n        x = x + self.positional_embedding.to(x.dtype)\n\n        x = self.patch_dropout(x)\n        x = self.ln_pre(x)\n        x = self.transformer(x)\n\n        if self.attn_pool is not None:\n            if self.attn_pool_contrastive is not None:\n                # This is untested, WIP pooling that should match paper\n                x = self.ln_post(x)  # TBD LN first or separate one after each pool?\n                tokens = self.attn_pool(x)\n                if self.attn_pool_type == 'parallel':\n                    pooled = self.attn_pool_contrastive(x)\n                else:\n                    assert self.attn_pool_type == 'cascade'\n                    pooled = self.attn_pool_contrastive(tokens)\n            else:\n                # this is the original OpenCLIP CoCa setup, does not match paper\n                x = self.attn_pool(x)\n                x = 
self.ln_post(x)\n                pooled, tokens = self._global_pool(x)\n        elif self.final_ln_after_pool:\n            pooled, tokens = self._global_pool(x)\n            pooled = self.ln_post(pooled)\n        else:\n            x = self.ln_post(x)\n            pooled, tokens = self._global_pool(x)\n\n        if self.proj is not None:\n            pooled = pooled @ self.proj\n\n        if self.output_tokens:\n            return pooled, tokens\n        \n        return pooled\n\n\ndef text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):\n    if pool_type == 'first':\n        pooled, tokens = x[:, 0], x[:, 1:]\n    elif pool_type == 'last':\n        pooled, tokens = x[:, -1], x[:, :-1]\n    elif pool_type == 'argmax':\n        # take features from the eot embedding (eot_token is the highest number in each sequence)\n        assert text is not None\n        pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x\n    else:\n        pooled = tokens = x\n\n    return pooled, tokens\n\n\nclass TextTransformer(nn.Module):\n    output_tokens: torch.jit.Final[bool]\n\n    def __init__(\n            self,\n            context_length: int = 77,\n            vocab_size: int = 49408,\n            width: int = 512,\n            heads: int = 8,\n            layers: int = 12,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            output_dim: int = 512,\n            embed_cls: bool = False,\n            no_causal_mask: bool = False,\n            pad_id: int = 0,\n            pool_type: str = 'argmax',\n            proj_bias: bool = False,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            output_tokens: bool = False,\n    ):\n        super().__init__()\n        assert pool_type in ('first', 'last', 'argmax', 'none')\n        self.output_tokens = output_tokens\n        self.num_pos = self.context_length = context_length\n        self.vocab_size = vocab_size\n        self.width = width\n        self.output_dim = output_dim\n        self.heads = heads\n        self.pad_id = pad_id\n        self.pool_type = pool_type\n\n        self.token_embedding = nn.Embedding(vocab_size, width)\n        if embed_cls:\n            self.cls_emb = nn.Parameter(torch.empty(width))\n            self.num_pos += 1\n        else:\n            self.cls_emb = None\n        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))\n        self.transformer = Transformer(\n            width=width,\n            layers=layers,\n            heads=heads,\n            mlp_ratio=mlp_ratio,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n        )\n        self.ln_final = norm_layer(width)\n\n        if no_causal_mask:\n            self.attn_mask = None\n        else:\n            self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)\n\n        if proj_bias:\n            self.text_projection = nn.Linear(width, output_dim)\n        else:\n            self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n\n        self.init_parameters()\n\n    def init_parameters(self):\n        nn.init.normal_(self.token_embedding.weight, std=0.02)\n        nn.init.normal_(self.positional_embedding, std=0.01)\n        if self.cls_emb is not None:\n            nn.init.normal_(self.cls_emb, std=0.01)\n\n        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n        
attn_std = self.transformer.width ** -0.5\n        fc_std = (2 * self.transformer.width) ** -0.5\n        for block in self.transformer.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            if isinstance(self.text_projection, nn.Linear):\n                nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)\n                if self.text_projection.bias is not None:\n                    nn.init.zeros_(self.text_projection.bias)\n            else:\n                nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.transformer.grad_checkpointing = enable\n\n    def build_causal_mask(self):\n        # lazily create causal attention mask, with full attention between the tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.num_pos, self.num_pos)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def build_cls_mask(self, text, cast_dtype: torch.dtype):\n        cls_mask = (text != self.pad_id).unsqueeze(1)\n        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)\n        additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)\n        additive_mask.fill_(0)\n        additive_mask.masked_fill_(~cls_mask, float(\"-inf\"))\n        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)\n        return additive_mask\n\n    def forward(self, text):\n        cast_dtype = self.transformer.get_cast_dtype()\n        seq_len = text.shape[1]\n\n        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]\n        attn_mask = self.attn_mask\n        if self.cls_emb is not None:\n            seq_len += 1\n            x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)\n            cls_mask = self.build_cls_mask(text, cast_dtype)\n            if attn_mask is not None:\n                attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]\n\n        x = x + self.positional_embedding[:seq_len].to(cast_dtype)\n        x = self.transformer(x, attn_mask=attn_mask)\n\n        # x.shape = [batch_size, n_ctx, transformer.width]\n        if self.cls_emb is not None:\n            # presence of appended cls embed (CoCa) overrides pool_type, always take last token\n            pooled, tokens = text_global_pool(x, pool_type='last')\n            pooled = self.ln_final(pooled)  # final LN applied after pooling in this case\n        else:\n            x = self.ln_final(x)\n            pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type)\n\n        if self.text_projection is not None:\n            if isinstance(self.text_projection, nn.Linear):\n                pooled = self.text_projection(pooled)\n            else:\n                pooled = pooled @ self.text_projection\n\n        if self.output_tokens:\n            return pooled, tokens\n\n        return pooled\n\n\nclass MultimodalTransformer(Transformer):\n    def __init__(\n            self,\n            width: int,\n            layers: int,\n            heads: int,\n            
context_length: int = 77,\n            mlp_ratio: float = 4.0,\n            ls_init_value: float = None,\n            act_layer: Callable = nn.GELU,\n            norm_layer: Callable = LayerNorm,\n            output_dim: int = 512,\n            batch_first: bool = True,\n    ):\n        super().__init__(\n            width=width,\n            layers=layers,\n            heads=heads,\n            mlp_ratio=mlp_ratio,\n            ls_init_value=ls_init_value,\n            act_layer=act_layer,\n            norm_layer=norm_layer,\n            batch_first=batch_first,\n        )\n        self.context_length = context_length\n        self.cross_attn = nn.ModuleList([\n            ResidualAttentionBlock(\n                width,\n                heads,\n                mlp_ratio,\n                ls_init_value=ls_init_value,\n                act_layer=act_layer,\n                norm_layer=norm_layer,\n                is_cross_attention=True,\n                batch_first=batch_first,\n            )\n            for _ in range(layers)\n        ])\n\n        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\n\n        self.ln_final = norm_layer(width)\n        self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n\n    def init_parameters(self):\n        # note: this class inherits directly from Transformer, so the block stack lives on self\n        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)\n        attn_std = self.width ** -0.5\n        fc_std = (2 * self.width) ** -0.5\n        for block in self.resblocks:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n        for block in self.cross_attn:\n            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n\n        if self.text_projection is not None:\n            nn.init.normal_(self.text_projection, std=self.width ** -0.5)\n\n    def build_attention_mask(self):\n        # lazily create causal attention mask, with full attention between the tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(self.context_length, self.context_length)\n        mask.fill_(float(\"-inf\"))\n        mask.triu_(1)  # zero out the lower diagonal\n        return mask\n\n    def forward(self, image_embs, text_embs):\n        seq_len = text_embs.shape[1]\n        if not self.batch_first:\n            image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND\n            text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND\n\n        for resblock, cross_attn in zip(self.resblocks, self.cross_attn):\n            if self.grad_checkpointing and not torch.jit.is_scripting():\n                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n                text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])\n                text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)\n            else:\n                text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])\n                text_embs = cross_attn(text_embs, 
k_x=image_embs, v_x=image_embs)\n\n        if not self.batch_first:\n            text_embs = text_embs.permute(1, 0, 2)  # LND -> NLD\n\n        out = self.ln_final(text_embs)\n        if self.text_projection is not None:\n            out = out @ self.text_projection\n\n        return out\n\n    @torch.jit.ignore\n    def set_grad_checkpointing(self, enable=True):\n        self.grad_checkpointing = enable\n"
  },
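  {
    "path": "examples/transformer_towers_sketch.py",
    "content": "\"\"\" Minimal usage sketch for the text/vision towers defined in the transformer module.\n\nIllustrative only: this example file is not part of the original repository.\nThe import path follows the package layout implied by inf_clip/pretrained.py\n(`from .models.transformer import ...`); the hyper-parameters below are the\nconstructor defaults, chosen only so the forward passes run on dummy data.\n\"\"\"\nimport torch\n\nfrom inf_clip.models.transformer import TextTransformer, VisionTransformer\n\n# Text tower: returns a pooled embedding of shape [batch, output_dim].\ntext_tower = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=12, output_dim=512)\ntoken_ids = torch.randint(0, 49408, (2, 77))\ntext_features = text_tower(token_ids)\nprint(text_features.shape)  # torch.Size([2, 512])\n\n# Vision tower: 224px input, 32px patches, pooled embedding of shape [batch, output_dim].\nvision_tower = VisionTransformer(\n    image_size=224, patch_size=32, width=768, layers=12, heads=12, mlp_ratio=4.0, output_dim=512)\nimages = torch.randn(2, 3, 224, 224)\nimage_features = vision_tower(images)\nprint(image_features.shape)  # torch.Size([2, 512])\n"
  },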
  {
    "path": "inf_clip/openai.py",
    "content": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\n\nimport os\nimport warnings\nfrom typing import List, Optional, Union\n\nimport torch\n\nfrom .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD\nfrom .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url\nfrom .models.clip_arch import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype\n\n\n__all__ = [\"list_openai_models\", \"load_openai_model\"]\n\n\ndef list_openai_models() -> List[str]:\n    \"\"\"Returns the names of available CLIP models\"\"\"\n    return list_pretrained_models_by_tag('openai')\n\n\ndef load_openai_model(\n        name: str,\n        precision: Optional[str] = None,\n        device: Optional[Union[str, torch.device]] = None,\n        cache_dir: Optional[str] = None,\n):\n    \"\"\"Load a CLIP model\n\n    Parameters\n    ----------\n    name : str\n        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict\n    precision: str\n        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.\n    device : Union[str, torch.device]\n        The device to put the loaded model\n    cache_dir : Optional[str]\n        The directory to cache the downloaded model weights\n\n    Returns\n    -------\n    model : torch.nn.Module\n        The CLIP model\n    preprocess : Callable[[PIL.Image], torch.Tensor]\n        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input\n    \"\"\"\n    if device is None:\n        device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n    if precision is None:\n        precision = 'fp32' if device == 'cpu' else 'fp16'\n\n    if get_pretrained_url(name, 'openai'):\n        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)\n    elif os.path.isfile(name):\n        model_path = name\n    else:\n        raise RuntimeError(f\"Model {name} not found; available models = {list_openai_models()}\")\n\n    try:\n        # loading JIT archive\n        model = torch.jit.load(model_path, map_location=\"cpu\").eval()\n        state_dict = None\n    except RuntimeError:\n        # loading saved state dict\n        state_dict = torch.load(model_path, map_location=\"cpu\")\n\n    # Build a non-jit model from the OpenAI jitted model state dict\n    cast_dtype = get_cast_dtype(precision)\n    try:\n        model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)\n    except KeyError:\n        sd = {k[7:]: v for k, v in state_dict[\"state_dict\"].items()}\n        model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)\n\n    # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use\n    model = model.to(device)\n    # FIXME support pure fp16/bf16 precision modes\n    if precision != 'fp16':\n        model.float()\n        if precision == 'bf16':\n            # for bf16, convert back to low-precision\n            convert_weights_to_lp(model, dtype=torch.bfloat16)\n\n    # add mean / std attributes for consistency with OpenCLIP models\n    model.visual.image_mean = OPENAI_DATASET_MEAN\n    model.visual.image_std = OPENAI_DATASET_STD\n    return model\n"
  },
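  {
    "path": "examples/load_openai_sketch.py",
    "content": "\"\"\" Minimal usage sketch for inf_clip/openai.py.\n\nIllustrative only: this example file is not part of the original repository.\nIt exercises list_openai_models() and load_openai_model() as defined in that\nmodule; the first call downloads the checkpoint into the default cache dir.\n\"\"\"\nfrom inf_clip.openai import list_openai_models, load_openai_model\n\n# Names carrying the 'openai' pretrain tag, e.g. 'RN50', 'ViT-B-32', 'ViT-L-14', ...\nprint(list_openai_models())\n\n# precision defaults to fp16 on CUDA and fp32 on CPU; force fp32 on CPU here.\nmodel = load_openai_model('ViT-B-32', precision='fp32', device='cpu')\nmodel.eval()\n\n# load_openai_model returns the torch module only and attaches the OpenAI\n# normalization stats to the visual tower for convenience.\nprint(model.visual.image_mean, model.visual.image_std)\nprint(sum(p.numel() for p in model.parameters()), 'parameters')\n"
  },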
  {
    "path": "inf_clip/pretrained.py",
    "content": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom .models.clip_arch import CLIP, CustomTextCLIP\nfrom .models.transformer import TextTransformer, Transformer\nfrom .constants import (\n    IMAGENET_MEAN,\n    IMAGENET_STD,\n    INCEPTION_MEAN,\n    INCEPTION_STD,\n    OPENAI_DATASET_MEAN,\n    OPENAI_DATASET_STD,\n)\n\n__version__ = \"2.26.1\"\n\ntry:\n    from huggingface_hub import hf_hub_download\n    hf_hub_download = partial(hf_hub_download, library_name=\"open_clip\", library_version=__version__)\n    _has_hf_hub = True\nexcept ImportError:\n    hf_hub_download = None\n    _has_hf_hub = False\n\n\ndef _pcfg(url='', hf_hub='', **kwargs):\n    # OpenAI / OpenCLIP defaults\n    return {\n        'url': url,\n        'hf_hub': hf_hub,\n        'mean': OPENAI_DATASET_MEAN,\n        'std': OPENAI_DATASET_STD,\n        'interpolation': 'bicubic',\n        'resize_mode': 'shortest',\n        **kwargs,\n    }\n\n\ndef _slpcfg(url='', hf_hub='', **kwargs):\n    # SiGLIP defaults\n    return {\n        'url': url,\n        'hf_hub': hf_hub,\n        'mean': INCEPTION_MEAN,\n        'std': INCEPTION_STD,\n        'interpolation': 'bicubic',\n        'resize_mode': 'squash',\n        **kwargs,\n    }\n\n\ndef _apcfg(url='', hf_hub='', **kwargs):\n    # CLIPA defaults\n    return {\n        'url': url,\n        'hf_hub': hf_hub,\n        'mean': IMAGENET_MEAN,\n        'std': IMAGENET_STD,\n        'interpolation': 'bilinear',\n        'resize_mode': 'squash',\n        **kwargs,\n    }\n\n\ndef _mccfg(url='', hf_hub='', **kwargs):\n    # MobileCLIP\n    return {\n        'url': url,\n        'hf_hub': hf_hub,\n        'mean': (0., 0., 0.),\n        'std': (1., 1., 1.),\n        'interpolation': 'bilinear',\n        'resize_mode': 'shortest',\n        **kwargs,\n    }\n\n\n\n_RN50 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\"),\n    cc12m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\"),\n)\n\n_RN50_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\"),\n    cc12m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\"),\n)\n\n_RN101 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\"),\n)\n\n_RN101_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\"),\n    yfcc15m=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\"),\n)\n\n_RN50x4 = dict(\n    
openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\"),\n)\n\n_RN50x16 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\"),\n)\n\n_RN50x64 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt\"),\n)\n\n_VITB32 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\n    laion2b_e16=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth\"),\n    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),\n    # DataComp-XL models\n    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),\n    # DataComp-M models\n    datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),\n    commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),\n    commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),\n    commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),\n    commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),\n    commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),\n    commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),\n    # DataComp-S models\n    datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),\n    commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),\n    commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),\n    commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),\n    commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),\n    commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),\n    commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),\n)\n\n_VITB32_quickgelu = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\"),\n    metaclip_400m=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt\"),\n    metaclip_fullcc=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt\"),\n)\n\n_VITB32_256 = dict(\n    
datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),\n)\n\n_VITB16 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt\"),\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),\n    # DataComp-XL models\n    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),\n    # DataComp-L models\n    datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),\n    commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),\n    commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),\n    commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),\n    commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),\n    commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),\n    commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),\n    # DFN\n    dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/')\n)\n\n_VITB16_quickgelu = dict(\n    metaclip_400m=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt\"),\n    metaclip_fullcc=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt\"),\n)\n\n_VITB16_PLUS_240 = dict(\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt\"),\n)\n\n_VITL14 = dict(\n    openai=_pcfg(\n        \"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\"),\n    laion400m_e31=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt\"),\n    laion400m_e32=_pcfg(\n        \"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt\"),\n    laion2b_s32b_b82k=_pcfg(\n        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',\n        mean=INCEPTION_MEAN, std=INCEPTION_STD),\n    # DataComp-XL models\n    datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),\n    commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),\n    commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),\n    commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),\n)\n\n_VITL14_quickgelu = dict(\n    metaclip_400m=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt\"),\n    metaclip_fullcc=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt\"),\n    dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'),\n)\n\n_VITL14_336 = dict(\n    openai=_pcfg(\n        
\"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\"),\n)\n\n_VITH14 = dict(\n    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),\n)\n\n_VITH14_quickgelu = dict(\n    metaclip_fullcc=_pcfg(\n        \"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt\"),\n    dfn5b=_pcfg(\n        hf_hub='apple/DFN5B-CLIP-ViT-H-14/',\n        interpolation=\"bicubic\",\n        resize_mode=\"squash\"\n    ),\n)\n\n_VITH14_378_quickgelu = dict(\n    dfn5b=_pcfg(\n        hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',\n        interpolation=\"bicubic\",\n        resize_mode=\"squash\"\n    ),\n)\n\n_VITg14 = dict(\n    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),\n    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),\n)\n\n_VITbigG14 = dict(\n    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),\n)\n\n_robertaViTB32 = dict(\n    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),\n)\n\n_xlmRobertaBaseViTB32 = dict(\n    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),\n)\n\n_xlmRobertaLargeFrozenViTH14 = dict(\n    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),\n)\n\n_convnext_base = dict(\n    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),\n)\n\n_convnext_base_w = dict(\n    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),\n    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),\n    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),\n)\n\n_convnext_base_w_320 = dict(\n    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),\n    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),\n)\n\n_convnext_large_d = dict(\n    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),\n)\n\n_convnext_large_d_320 = dict(\n    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),\n    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),\n)\n\n_convnext_xxlarge = dict(\n    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),\n    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),\n    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),\n)\n\n_coca_VITB32 = dict(\n    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),\n    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')\n)\n\n_coca_VITL14 = dict(\n    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),\n    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')\n)\n\n\n_PRETRAINED = {\n    \"RN50\": _RN50,\n    \"RN50-quickgelu\": _RN50_quickgelu,\n    \"RN101\": _RN101,\n    \"RN101-quickgelu\": _RN101_quickgelu,\n    \"RN50x4\": _RN50x4,\n    \"RN50x16\": _RN50x16,\n    \"RN50x64\": _RN50x64,\n\n    \"ViT-B-32\": _VITB32,\n    \"ViT-B-32-256\": 
_VITB32_256,\n    \"ViT-B-32-quickgelu\": _VITB32_quickgelu,\n    \"ViT-B-16\": _VITB16,\n    \"ViT-B-16-quickgelu\": _VITB16_quickgelu,\n    \"ViT-B-16-plus-240\": _VITB16_PLUS_240,\n    \"ViT-L-14\": _VITL14,\n    \"ViT-L-14-quickgelu\": _VITL14_quickgelu,\n    \"ViT-L-14-336\": _VITL14_336,\n    \"ViT-H-14\": _VITH14,\n    \"ViT-H-14-quickgelu\": _VITH14_quickgelu,\n    \"ViT-H-14-378-quickgelu\": _VITH14_378_quickgelu,\n    \"ViT-g-14\": _VITg14,\n    \"ViT-bigG-14\": _VITbigG14,\n\n    \"roberta-ViT-B-32\": _robertaViTB32,\n    \"xlm-roberta-base-ViT-B-32\": _xlmRobertaBaseViTB32,\n    \"xlm-roberta-large-ViT-H-14\": _xlmRobertaLargeFrozenViTH14,\n\n    \"convnext_base\": _convnext_base,\n    \"convnext_base_w\": _convnext_base_w,\n    \"convnext_base_w_320\": _convnext_base_w_320,\n    \"convnext_large_d\": _convnext_large_d,\n    \"convnext_large_d_320\": _convnext_large_d_320,\n    \"convnext_xxlarge\": _convnext_xxlarge,\n\n    \"coca_ViT-B-32\": _coca_VITB32,\n    \"coca_ViT-L-14\": _coca_VITL14,\n\n    \"EVA01-g-14\": dict(\n        # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt\n        laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),\n    ),\n    \"EVA01-g-14-plus\": dict(\n        # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt\n        merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),\n    ),\n    \"EVA02-B-16\": dict(\n        # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt\n        merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),\n    ),\n    \"EVA02-L-14\": dict(\n        # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt\n        merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),\n    ),\n    \"EVA02-L-14-336\": dict(\n        # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt\n        merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),\n    ),\n    \"EVA02-E-14\": dict(\n        # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt\n        laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),\n    ),\n    \"EVA02-E-14-plus\": dict(\n        # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt\n        laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),\n    ),\n\n    \"ViT-B-16-SigLIP\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),\n    ),\n    \"ViT-B-16-SigLIP-256\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),\n    ),\n    \"ViT-B-16-SigLIP-i18n-256\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),\n    ),\n    \"ViT-B-16-SigLIP-384\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),\n    ),\n    \"ViT-B-16-SigLIP-512\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),\n    ),\n    \"ViT-L-16-SigLIP-256\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),\n    ),\n    \"ViT-L-16-SigLIP-384\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),\n    ),\n    \"ViT-SO400M-14-SigLIP\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),\n    ),\n    \"ViT-SO400M-14-SigLIP-384\": dict(\n        webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),\n    ),\n\n    \"ViT-L-14-CLIPA\": dict(\n        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),\n    ),\n    \"ViT-L-14-CLIPA-336\": dict(\n        
datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),\n    ),\n    \"ViT-H-14-CLIPA\": dict(\n        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),\n    ),\n    \"ViT-H-14-CLIPA-336\": dict(\n        laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),\n        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),\n    ),\n    \"ViT-bigG-14-CLIPA\": dict(\n        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),\n    ),\n    \"ViT-bigG-14-CLIPA-336\": dict(\n        datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),\n    ),\n\n    \"nllb-clip-base\": dict(\n        v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),\n    ),\n    \"nllb-clip-large\": dict(\n        v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),\n    ),\n\n    \"nllb-clip-base-siglip\": dict(\n        v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),\n        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),\n    ),\n    \"nllb-clip-large-siglip\": dict(\n        v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),\n        mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),\n    ),\n\n    \"MobileCLIP-S1\": dict(\n        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),\n    \"MobileCLIP-S2\": dict(\n        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),\n    \"MobileCLIP-B\": dict(\n        datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),\n        datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),\n    ),\n\n    \"ViTamin-S\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),\n    ),\n    \"ViTamin-S-LTT\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),\n    ),\n    \"ViTamin-B\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),\n    ),\n    \"ViTamin-B-LTT\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),\n    ),\n    \"ViTamin-L\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L-256\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L-336\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L-384\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L2\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L2-256\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L2-336\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),\n    ),\n    \"ViTamin-L2-384\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),\n    ),\n    \"ViTamin-XL-256\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),\n    ),\n    \"ViTamin-XL-336\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),\n    ),\n    \"ViTamin-XL-384\": dict(\n        datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),\n    ),\n}\n\n\ndef _clean_tag(tag: str):\n    # normalize pretrained tags\n    return tag.lower().replace('-', '_')\n\n\ndef list_pretrained(as_str: bool = False):\n    \"\"\" returns list of 
pretrained models\n    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\n    \"\"\"\n    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]\n\n\ndef list_pretrained_models_by_tag(tag: str):\n    \"\"\" return all models having the specified pretrain tag \"\"\"\n    models = []\n    tag = _clean_tag(tag)\n    for k in _PRETRAINED.keys():\n        if tag in _PRETRAINED[k]:\n            models.append(k)\n    return models\n\n\ndef list_pretrained_tags_by_model(model: str):\n    \"\"\" return all pretrain tags for the specified model architecture \"\"\"\n    tags = []\n    if model in _PRETRAINED:\n        tags.extend(_PRETRAINED[model].keys())\n    return tags\n\n\ndef is_pretrained_cfg(model: str, tag: str):\n    if model not in _PRETRAINED:\n        return False\n    return _clean_tag(tag) in _PRETRAINED[model]\n\n\ndef get_pretrained_cfg(model: str, tag: str):\n    if model not in _PRETRAINED:\n        return {}\n    model_pretrained = _PRETRAINED[model]\n    return model_pretrained.get(_clean_tag(tag), {})\n\n\ndef get_pretrained_url(model: str, tag: str):\n    cfg = get_pretrained_cfg(model, _clean_tag(tag))\n    return cfg.get('url', '')\n\n\ndef download_pretrained_from_url(\n        url: str,\n        cache_dir: Union[str, None] = None,\n):\n    if not cache_dir:\n        cache_dir = os.path.expanduser(\"~/.cache/clip\")\n    os.makedirs(cache_dir, exist_ok=True)\n    filename = os.path.basename(url)\n\n    if 'openaipublic' in url:\n        expected_sha256 = url.split(\"/\")[-2]\n    elif 'mlfoundations' in url:\n        expected_sha256 = os.path.splitext(filename)[0].split(\"-\")[-1]\n    else:\n        expected_sha256 = ''\n\n    download_target = os.path.join(cache_dir, filename)\n\n    if os.path.exists(download_target) and not os.path.isfile(download_target):\n        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n\n    if os.path.isfile(download_target):\n        if expected_sha256:\n            if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n                return download_target\n            else:\n                warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n        else:\n            return download_target\n\n    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n        with tqdm(total=int(source.headers.get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\n            while True:\n                buffer = source.read(8192)\n                if not buffer:\n                    break\n\n                output.write(buffer)\n                loop.update(len(buffer))\n\n    if expected_sha256 and not hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n        raise RuntimeError(f\"Model has been downloaded but the SHA256 checksum does not match\")\n\n    return download_target\n\n\ndef has_hf_hub(necessary=False):\n    if not _has_hf_hub and necessary:\n        # if no HF Hub module installed, and it is necessary to continue, raise error\n        raise RuntimeError(\n            'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.')\n    return _has_hf_hub\n\n\ndef download_pretrained_from_hf(\n        model_id: str,\n        filename: str = 'open_clip_pytorch_model.bin',\n        revision=None,\n        cache_dir: Union[str, None] = None,\n):\n    has_hf_hub(True)\n    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)\n    return cached_file\n\n\ndef download_pretrained(\n        cfg: Dict,\n        force_hf_hub: bool = False,\n        cache_dir: Union[str, None] = None,\n):\n    target = ''\n    if not cfg:\n        return target\n\n    download_url = cfg.get('url', '')\n    download_hf_hub = cfg.get('hf_hub', '')\n    if download_hf_hub and force_hf_hub:\n        # use HF hub even if url exists\n        download_url = ''\n\n    if download_url:\n        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)\n    elif download_hf_hub:\n        has_hf_hub(True)\n        # we assume the hf_hub entries in pretrained config combine model_id + filename in\n        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and\n        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.\n        model_id, filename = os.path.split(download_hf_hub)\n        if filename:\n            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)\n        else:\n            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n\n    return target\n\n\n@torch.no_grad()\ndef load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str):\n    \"\"\" Load weights from .npz checkpoints for official Google big_vision image-text models\n\n    Currently the SigLIP source models are supported and a CustomTextCLIP destination model\n    w/ timm image encoder.\n    \"\"\"\n    from timm.layers import resample_patch_embed, resample_abs_pos_embed\n\n    def _n2p(w, t=True):\n        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:\n            w = w.flatten()\n        if t:\n            if w.ndim == 4:\n                w = w.transpose([3, 2, 0, 1])\n            elif w.ndim == 3:\n                w = w.transpose([2, 0, 1])\n            elif w.ndim == 2:\n                w = w.transpose([1, 0])\n        return torch.from_numpy(w)\n\n    w = np.load(checkpoint_path)\n    interpolation = 'bilinear'\n    antialias = False\n\n    def _convert_timm_img(module, prefix):\n        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])\n        if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]:\n            embed_conv_w = resample_patch_embed(\n                embed_conv_w,\n                module.patch_embed.proj.weight.shape[-2:],\n                interpolation=interpolation,\n                antialias=antialias,\n                verbose=True,\n            )\n        module.patch_embed.proj.weight.copy_(embed_conv_w)\n        module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))\n\n        if module.cls_token is not None:\n            module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))\n\n        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)\n        if pos_embed_w.shape != module.pos_embed.shape:\n            assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}'\n            num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1)\n            pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different 
size from pretrained weights\n                pos_embed_w,\n                new_size=module.patch_embed.grid_size,\n                num_prefix_tokens=num_prefix_tokens,\n                interpolation=interpolation,\n                antialias=antialias,\n                verbose=True,\n            )\n        module.pos_embed.copy_(pos_embed_w)\n\n        mha_sub, b_sub, ln1_sub = (0, 0, 1)\n        for i, block in enumerate(module.blocks.children()):\n            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'\n            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'\n            block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))\n            block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))\n            block.attn.qkv.weight.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))\n            block.attn.qkv.bias.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))\n            block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))\n            block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))\n            for r in range(2):\n                getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))\n                getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))\n            block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))\n            block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))\n\n        module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))\n        module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))\n\n        if module.attn_pool is not None:\n            block_prefix = f'{prefix}MAPHead_0/'\n            mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'\n            module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))\n            module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)\n            module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))\n            module.attn_pool.kv.weight.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))\n            module.attn_pool.kv.bias.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))\n            module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))\n            module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))\n            module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))\n            module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))\n            for r in range(2):\n                getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))\n                getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))\n\n    def _convert_openclip_transformer(module: Transformer, prefix):\n        for i, block in enumerate(module.resblocks.children()):\n            block_prefix = f'{prefix}encoderblock_{i}/'\n            mha_prefix = block_prefix + 
f'MultiHeadDotProductAttention_0/'\n            block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))\n            block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))\n            block.attn.in_proj_weight.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))\n            block.attn.in_proj_bias.copy_(torch.cat([\n                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))\n            block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))\n            block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))\n            block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale']))\n            block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias']))\n            block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel']))\n            block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias']))\n            block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel']))\n            block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias']))\n\n    def _convert_openclip_txt(module: TextTransformer, prefix):\n        module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False))\n        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0)\n        module.positional_embedding.copy_(pos_embed_w)\n        _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/')\n        module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale']))\n        module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias']))\n        module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel']))\n        module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias']))\n\n    _convert_timm_img(model.visual.trunk, 'params/img/')\n    _convert_openclip_txt(model.text, 'params/txt/')\n    model.logit_bias.copy_(_n2p(w['params/b'])[0])\n    model.logit_scale.copy_(_n2p(w['params/t'])[0])\n\n\n@torch.no_grad()\ndef convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):\n\n    def _convert_timm_img(state_dict):\n        if fastvit:\n            from timm.models.fastvit import checkpoint_filter_fn\n        else:\n            from timm.models.vision_transformer_hybrid import checkpoint_filter_fn\n        timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)\n        timm_state_dict = {'visual.trunk.' 
+ k: v for k, v in timm_state_dict.items()}\n        return timm_state_dict\n\n    def _convert_openclip_txt(state_dict, prefix='text_encoder.'):\n        text_dict = {}\n        for k, v in state_dict.items():\n            if not k.startswith(prefix):\n                continue\n            k = k.replace(prefix, '')\n            k = k.replace('projection_layer', 'text_projection')\n            k = k.replace('embedding_layer', 'token_embedding')\n            if k.startswith('positional_embedding.pos_embed.pos_embed'):\n                k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')\n                v = v.squeeze()\n            k = k.replace('final_layer_norm', 'ln_final')\n            k = k.replace('pre_norm_mha.0', 'ln_1')\n            k = k.replace('pre_norm_mha.1', 'attn')\n            k = k.replace('pre_norm_ffn.0', 'ln_2')\n            k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')\n            k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')\n            k = k.replace('qkv_proj.weight', 'in_proj_weight')\n            k = k.replace('qkv_proj.bias', 'in_proj_bias')\n            k = k.replace('transformer.', 'transformer.resblocks.')\n            text_dict['text.' + k] = v\n        return text_dict\n\n    image_dict = _convert_timm_img(state_dict)\n    text_dict = _convert_openclip_txt(state_dict)\n    out_dict = {**image_dict, **text_dict}\n    out_dict['logit_scale'] = state_dict['logit_scale']\n    return out_dict\n\n\ndef convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):\n    if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:\n        # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)\n        state_dict = convert_mobile_clip_state_dict(model, state_dict)\n    if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:\n        # convert b model\n        state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)\n    return state_dict\n"
  },
  {
    "path": "inf_clip/train/data.py",
    "content": "import ast\nimport json\nimport logging\nimport math\nimport os\nimport random\nimport sys\nimport braceexpand\nfrom dataclasses import dataclass\nfrom multiprocessing import Value\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torchvision.datasets as datasets\nimport webdataset as wds\nfrom PIL import Image\nfrom torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info\nfrom torch.utils.data.distributed import DistributedSampler\nfrom webdataset.filters import _shuffle\nfrom webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample\n\n\nclass CsvDataset(Dataset):\n    def __init__(self, input_filename, transforms, img_key, caption_key, sep=\"\\t\", tokenizer=None):\n        logging.debug(f'Loading csv data from {input_filename}.')\n        df = pd.read_csv(input_filename, sep=sep)\n\n        self.images = df[img_key].tolist()\n        self.captions = df[caption_key].tolist()\n        self.transforms = transforms\n        logging.debug('Done loading data.')\n\n        self.tokenize = tokenizer\n\n    def __len__(self):\n        return len(self.captions)\n\n    def __getitem__(self, idx):\n        images = self.transforms(Image.open(str(self.images[idx])))\n        texts = self.tokenize([str(self.captions[idx])])[0]\n        return images, texts\n\n\nclass SharedEpoch:\n    def __init__(self, epoch: int = 0):\n        self.shared_epoch = Value('i', epoch)\n\n    def set_value(self, epoch):\n        self.shared_epoch.value = epoch\n\n    def get_value(self):\n        return self.shared_epoch.value\n\n\n@dataclass\nclass DataInfo:\n    dataloader: DataLoader\n    sampler: DistributedSampler = None\n    shared_epoch: SharedEpoch = None\n\n    def set_epoch(self, epoch):\n        if self.shared_epoch is not None:\n            self.shared_epoch.set_value(epoch)\n        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):\n            self.sampler.set_epoch(epoch)\n\n\ndef expand_urls(urls, weights=None):\n    if weights is None:\n        expanded_urls = wds.shardlists.expand_urls(urls)\n        return expanded_urls, None\n    if isinstance(urls, str):\n        urllist = urls.split(\"::\")\n        weights = weights.split('::')\n        assert len(weights) == len(urllist),\\\n            f\"Expected the number of data components ({len(urllist)}) and weights({len(weights)}) to match.\"\n        weights = [float(weight) for weight in weights]\n        all_urls, all_weights = [], []\n        for url, weight in zip(urllist, weights):\n            expanded_url = list(braceexpand.braceexpand(url))\n            expanded_weights = [weight for _ in expanded_url]\n            all_urls.extend(expanded_url)\n            all_weights.extend(expanded_weights)\n        return all_urls, all_weights\n    else:\n        all_urls = list(urls)\n        return all_urls, weights\n\n\ndef get_dataset_size(shards):\n    shards_list, _ = expand_urls(shards)\n    dir_path = os.path.dirname(shards_list[0])\n    sizes_filename = os.path.join(dir_path, 'sizes.json')\n    len_filename = os.path.join(dir_path, '__len__')\n    if os.path.exists(sizes_filename):\n        sizes = json.load(open(sizes_filename, 'r'))\n        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])\n    elif os.path.exists(len_filename):\n        # FIXME this used to be eval(open(...)) but that seemed rather unsafe\n        total_size = ast.literal_eval(open(len_filename, 'r').read())\n    
else:\n        total_size = None  # num samples undefined\n        # some common dataset sizes (at time of authors last download)\n        # CC3M (train): 2905954\n        # CC12M: 10968539\n        # LAION-400M: 407332084\n        # LAION-2B (english): 2170337258\n    num_shards = len(shards_list)\n    return total_size, num_shards\n\n\ndef get_imagenet(args, preprocess_fns, split):\n    assert split in [\"train\", \"val\", \"v2\"]\n    is_train = split == \"train\"\n    preprocess_train, preprocess_val = preprocess_fns\n\n    if split == \"v2\":\n        from imagenetv2_pytorch import ImageNetV2Dataset\n        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)\n    else:\n        if is_train:\n            data_path = args.imagenet_train\n            preprocess_fn = preprocess_train\n        else:\n            data_path = args.imagenet_val\n            preprocess_fn = preprocess_val\n        assert data_path\n\n        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)\n\n    if is_train:\n        idxs = np.zeros(len(dataset.targets))\n        target_array = np.array(dataset.targets)\n        k = 50\n        for c in range(1000):\n            m = target_array == c\n            n = len(idxs[m])\n            arr = np.zeros(n)\n            arr[:k] = 1\n            np.random.shuffle(arr)\n            idxs[m] = arr\n\n        idxs = idxs.astype('int')\n        sampler = SubsetRandomSampler(np.where(idxs)[0])\n    else:\n        sampler = None\n\n    dataloader = torch.utils.data.DataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        num_workers=args.workers,\n        sampler=sampler,\n        pin_memory=True,\n    )\n\n    return DataInfo(dataloader=dataloader, sampler=sampler)\n\n\ndef count_samples(dataloader):\n    os.environ[\"WDS_EPOCH\"] = \"0\"\n    n_elements, n_batches = 0, 0\n    for images, texts in dataloader:\n        n_batches += 1\n        n_elements += len(images)\n        assert len(images) == len(texts)\n    return n_elements, n_batches\n\n\ndef filter_no_caption_or_no_image(sample):\n    has_caption = ('txt' in sample or 'json' in sample)\n    has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)\n    return has_caption and has_image\n\n\ndef log_and_continue(exn):\n    \"\"\"Call in an exception handler to ignore any exception, issue a warning, and continue.\"\"\"\n    logging.warning(f'Handling webdataset error ({repr(exn)}). 
Ignoring.')\n    return True\n\n\ndef group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):\n    \"\"\"Return function over iterator that groups key, value pairs into samples.\n\n    :param keys: function that splits the key into key and extension (base_plus_ext)\n    :param lcase: convert suffixes to lower case (Default value = True)\n    \"\"\"\n    current_sample = None\n    for filesample in data:\n        assert isinstance(filesample, dict)\n        # FIXME this is a bit of a hack to handle the fact that the CC3M/LAION400m dataset has some empty files.\n        try:\n            fname, value = filesample[\"fname\"], filesample[\"data\"]\n        except KeyError as exn:\n            continue\n        prefix, suffix = keys(fname)\n        if prefix is None:\n            continue\n        if lcase:\n            suffix = suffix.lower()\n        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for\n        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next\n        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset\n        if current_sample is None or prefix != current_sample[\"__key__\"] or suffix in current_sample:\n            if valid_sample(current_sample):\n                yield current_sample\n            current_sample = dict(__key__=prefix, __url__=filesample[\"__url__\"])\n        if suffixes is None or suffix in suffixes:\n            current_sample[suffix] = value\n    if valid_sample(current_sample):\n        yield current_sample\n\n\ndef tarfile_to_samples_nothrow(src, handler=log_and_continue):\n    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw\n    streams = url_opener(src, handler=handler)\n    files = tar_file_expander(streams, handler=handler)\n    samples = group_by_keys_nothrow(files, handler=handler)\n    return samples\n\n\ndef pytorch_worker_seed(increment=0):\n    \"\"\"get dataloader worker seed from pytorch\"\"\"\n    worker_info = get_worker_info()\n    if worker_info is not None:\n        # favour using the seed already created for pytorch dataloader workers if it exists\n        seed = worker_info.seed\n        if increment:\n            # space out seed increments so they can't overlap across workers in different iterations\n            seed += increment * max(1, worker_info.num_workers)\n        return seed\n    # fallback to wds rank based seed\n    return wds.utils.pytorch_worker_seed()\n\n\ndef json_fetch(data, key='caption'):\n    for sample in data:\n        if 'json' in sample:\n            if isinstance(sample['json'], dict):\n                value = sample['json'].get(key, None)\n            else:\n                value = sample['json']\n        else:\n            value = sample['txt']\n        if isinstance(value, str):\n            sample['txt'] = [value]\n        elif isinstance(value, list):\n            sample['txt'] = value\n        else:\n            # print(f\"Expected {key} to be a string or list of strings, got {type(value)} {sample}\")\n            continue\n\n        yield sample\n\n\n_SHARD_SHUFFLE_SIZE = 2000\n_SHARD_SHUFFLE_INITIAL = 500\n_SAMPLE_SHUFFLE_SIZE = 5000\n_SAMPLE_SHUFFLE_INITIAL = 1000\n\n\nclass detshuffle2(wds.PipelineStage):\n    def __init__(\n            self,\n            bufsize=1000,\n            initial=100,\n            seed=0,\n            epoch=-1,\n    ):\n        self.bufsize = bufsize\n        self.initial = 
initial\n        self.seed = seed\n        self.epoch = epoch\n\n    def run(self, src):\n        if isinstance(self.epoch, SharedEpoch):\n            epoch = self.epoch.get_value()\n        else:\n            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)\n            # situation as different workers may wrap at different times (or not at all).\n            self.epoch += 1\n            epoch = self.epoch\n        rng = random.Random()\n        if self.seed < 0:\n            # If seed is negative, we use the worker's seed, this will be different across all nodes/workers\n            seed = pytorch_worker_seed(epoch)\n        else:\n            # This seed to be deterministic AND the same across all nodes/workers in each epoch\n            seed = self.seed + epoch\n        rng.seed(seed)\n        return _shuffle(src, self.bufsize, self.initial, rng)\n\n\nclass ResampledShards2(IterableDataset):\n    \"\"\"An iterable dataset yielding a list of urls.\"\"\"\n\n    def __init__(\n        self,\n        urls,\n        weights=None,\n        nshards=sys.maxsize,\n        worker_seed=None,\n        deterministic=False,\n        epoch=-1,\n    ):\n        \"\"\"Sample shards from the shard list with replacement.\n\n        :param urls: a list of URLs as a Python list or brace notation string\n        \"\"\"\n        super().__init__()\n        urls, weights = expand_urls(urls, weights)\n        self.urls = urls\n        self.weights = weights\n        if self.weights is not None:\n            assert len(self.urls) == len(self.weights),\\\n                f\"Number of urls {len(self.urls)} and weights {len(self.weights)} should match.\"\n        assert isinstance(self.urls[0], str)\n        self.nshards = nshards\n        self.rng = random.Random()\n        self.worker_seed = worker_seed\n        self.deterministic = deterministic\n        self.epoch = epoch\n\n    def __iter__(self):\n        \"\"\"Return an iterator over the shards.\"\"\"\n        if isinstance(self.epoch, SharedEpoch):\n            epoch = self.epoch.get_value()\n        else:\n            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)\n            # situation as different workers may wrap at different times (or not at all).\n            self.epoch += 1\n            epoch = self.epoch\n        if self.deterministic:\n            # reset seed w/ epoch if deterministic\n            if self.worker_seed is None:\n                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id\n                seed = pytorch_worker_seed(epoch)\n            else:\n                seed = self.worker_seed() + epoch\n            self.rng.seed(seed)\n        for _ in range(self.nshards):\n            if self.weights is None:\n                yield dict(url=self.rng.choice(self.urls))\n            else:\n                yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])\n\n\ndef get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):\n    input_shards = args.train_data if is_train else args.val_data\n    assert input_shards is not None\n    resampled = getattr(args, 'dataset_resampled', False) and is_train\n\n    num_shards = None\n    if is_train:\n        if args.train_num_samples is not None:\n            num_samples = args.train_num_samples\n        else:\n            num_samples, num_shards = get_dataset_size(input_shards)\n            if not num_samples:\n    
            raise RuntimeError(\n                    'Currently, the number of dataset samples must be specified for the training dataset. '\n                    'Please specify it via `--train-num-samples` if no dataset length info is present.')\n    else:\n        # Eval will just exhaust the iterator if the size is not specified.\n        num_samples = args.val_num_samples or 0 \n\n    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc\n\n    if is_train and args.train_data_upsampling_factors is not None:\n        assert resampled, \"--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled).\"\n    \n    if resampled:\n        pipeline = [ResampledShards2(\n            input_shards,\n            weights=args.train_data_upsampling_factors,\n            deterministic=True,\n            epoch=shared_epoch,\n        )]\n    else:\n        pipeline = [wds.SimpleShardList(input_shards)]\n\n    # at this point we have an iterator over all the shards\n    if is_train:\n        if not resampled:\n            pipeline.extend([\n                detshuffle2(\n                    bufsize=_SHARD_SHUFFLE_SIZE,\n                    initial=_SHARD_SHUFFLE_INITIAL,\n                    seed=args.seed,\n                    epoch=shared_epoch,\n                ),\n                wds.split_by_node,\n                wds.split_by_worker,\n            ])\n        pipeline.extend([\n            # at this point, we have an iterator over the shards assigned to each worker at each node\n            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),\n            wds.shuffle(\n                bufsize=_SAMPLE_SHUFFLE_SIZE,\n                initial=_SAMPLE_SHUFFLE_INITIAL,\n            ),\n        ])\n    else:\n        pipeline.extend([\n            wds.split_by_worker,\n            # at this point, we have an iterator over the shards assigned to each worker\n            wds.tarfile_to_samples(handler=log_and_continue),\n        ])\n    pipeline.extend([\n        wds.select(filter_no_caption_or_no_image),\n        wds.decode(\"pilrgb\", handler=log_and_continue),\n        json_fetch,\n        wds.rename(image=\"jpg;png;jpeg;webp\", text=\"txt;json\"),\n        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),\n        wds.to_tuple(\"image\", \"text\"),\n        wds.batched(args.batch_size, partial=not is_train)\n    ])\n\n    dataset = wds.DataPipeline(*pipeline)\n\n    if is_train:\n        if not resampled:\n            num_shards = num_shards or len(expand_urls(input_shards)[0])\n            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'\n        # roll over and repeat a few samples to get same number of full batches on each node\n        round_fn = math.floor if floor else math.ceil\n        global_batch_size = args.batch_size * args.world_size\n        num_batches = round_fn(num_samples / global_batch_size)\n        num_workers = max(1, args.workers)\n        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker\n        num_batches = num_worker_batches * num_workers\n        num_samples = num_batches * global_batch_size\n        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this\n    else:\n        # last batches are partial, eval is done on single (master) node\n        num_batches = math.ceil(num_samples / args.batch_size)\n\n    
dataloader = wds.WebLoader(\n        dataset,\n        batch_size=None,\n        shuffle=False,\n        num_workers=args.workers,\n        persistent_workers=args.workers > 0,\n    )\n\n    # FIXME not clear which approach is better, with_epoch before vs after dataloader?\n    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169\n    # if is_train:\n    #     # roll over and repeat a few samples to get same number of full batches on each node\n    #     global_batch_size = args.batch_size * args.world_size\n    #     num_batches = math.ceil(num_samples / global_batch_size)\n    #     num_workers = max(1, args.workers)\n    #     num_batches = math.ceil(num_batches / num_workers) * num_workers\n    #     num_samples = num_batches * global_batch_size\n    #     dataloader = dataloader.with_epoch(num_batches)\n    # else:\n    #     # last batches are partial, eval is done on single (master) node\n    #     num_batches = math.ceil(num_samples / args.batch_size)\n\n    # add meta-data to dataloader instance for convenience\n    dataloader.num_batches = num_batches\n    dataloader.num_samples = num_samples\n    dataloader.batch_size = args.batch_size\n\n    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)\n\n\ndef get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):\n    input_filename = args.train_data if is_train else args.val_data\n    assert input_filename\n    dataset = CsvDataset(\n        input_filename,\n        preprocess_fn,\n        img_key=args.csv_img_key,\n        caption_key=args.csv_caption_key,\n        sep=args.csv_separator,\n        tokenizer=tokenizer\n    )\n    num_samples = len(dataset)\n    sampler = DistributedSampler(dataset) if args.distributed and is_train else None\n    shuffle = is_train and sampler is None\n\n    dataloader = DataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        shuffle=shuffle,\n        num_workers=args.workers,\n        pin_memory=True,\n        sampler=sampler,\n        drop_last=is_train,\n    )\n    dataloader.num_samples = num_samples\n    dataloader.num_batches = len(dataloader)\n\n    return DataInfo(dataloader, sampler)\n\n\nclass SyntheticDataset(Dataset):\n\n    def __init__(\n            self,\n            transform=None,\n            image_size=(224, 224),\n            caption=\"Dummy caption\",\n            dataset_size=100,\n            tokenizer=None,\n    ):\n        self.transform = transform\n        self.image_size = image_size\n        self.caption = caption\n        self.image = Image.new('RGB', image_size)\n        self.dataset_size = dataset_size\n\n        self.preprocess_txt = lambda text: tokenizer(text)[0]\n\n    def __len__(self):\n        return self.dataset_size\n\n    def __getitem__(self, idx):\n        if self.transform is not None:\n            image = self.transform(self.image)\n        return image, self.preprocess_txt(self.caption)\n\n\ndef get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):\n    image_size = preprocess_fn.transforms[0].size\n    dataset = SyntheticDataset(\n        transform=preprocess_fn, image_size=image_size, dataset_size=args.train_num_samples, tokenizer=tokenizer)\n    num_samples = len(dataset)\n    sampler = DistributedSampler(dataset) if args.distributed and is_train else None\n    shuffle = is_train and sampler is None\n\n    dataloader = DataLoader(\n        dataset,\n        batch_size=args.batch_size,\n        shuffle=shuffle,\n        num_workers=args.workers,\n        
pin_memory=True,\n        sampler=sampler,\n        drop_last=is_train,\n    )\n    dataloader.num_samples = num_samples\n    dataloader.num_batches = len(dataloader)\n\n    return DataInfo(dataloader, sampler)\n\n\ndef get_dataset_fn(data_path, dataset_type):\n    if dataset_type == \"webdataset\":\n        return get_wds_dataset\n    elif dataset_type == \"csv\":\n        return get_csv_dataset\n    elif dataset_type == \"synthetic\":\n        return get_synthetic_dataset\n    elif dataset_type == \"auto\":\n        ext = data_path.split('.')[-1]\n        if ext in ['csv', 'tsv']:\n            return get_csv_dataset\n        elif ext in ['tar']:\n            return get_wds_dataset\n        else:\n            raise ValueError(\n                f\"Tried to figure out dataset type, but failed for extension {ext}.\")\n    else:\n        raise ValueError(f\"Unsupported dataset type: {dataset_type}\")\n    \n\ndef get_data(args, preprocess_fns, epoch=0, tokenizer=None):\n    preprocess_train, preprocess_val = preprocess_fns\n    data = {}\n\n    if args.train_data or args.dataset_type == \"synthetic\":\n        data[\"train\"] = get_dataset_fn(args.train_data, args.dataset_type)(\n            args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)\n\n    if args.val_data:\n        data[\"val\"] = get_dataset_fn(args.val_data, args.dataset_type)(\n            args, preprocess_val, is_train=False, tokenizer=tokenizer)\n\n    if args.imagenet_val is not None:\n        data[\"imagenet-val\"] = get_imagenet(args, preprocess_fns, \"val\")\n\n    if args.imagenet_v2 is not None:\n        data[\"imagenet-v2\"] = get_imagenet(args, preprocess_fns, \"v2\")\n\n    return data\n"
  },
  {
    "path": "inf_clip/train/engine.py",
    "content": "import json\nimport logging\nimport math\nimport os\nimport time\nfrom contextlib import nullcontext\n\nimport numpy as np\nimport pynvml\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint as torch_checkpoint\nimport torch.distributed as dist\nfrom torch.nn.parallel.distributed import DistributedDataParallel\nfrom tqdm import tqdm\n\nfrom .utils import get_autocast, is_master\nfrom inf_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \\\n    IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES, CLIP, CustomTextCLIP\nfrom inf_clip.models.loss import ClipLoss\n\ntry:\n    import wandb\nexcept ImportError:\n    wandb = None\n\n\ndef accuracy(output, target, topk=(1,)):\n    pred = output.topk(max(topk), 1, True, True)[1].t()\n    correct = pred.eq(target.view(1, -1).expand_as(pred))\n    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]\n\n\ndef get_clip_metrics(image_features, text_features, logit_scale):\n    metrics = {}\n    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()\n    logits_per_text = logits_per_image.t().detach().cpu()\n\n    logits = {\"image_to_text\": logits_per_image, \"text_to_image\": logits_per_text}\n    ground_truth = torch.arange(len(text_features)).view(-1, 1)\n\n    for name, logit in logits.items():\n        ranking = torch.argsort(logit, descending=True)\n        preds = torch.where(ranking == ground_truth)[1]\n        preds = preds.detach().cpu().numpy()\n        metrics[f\"{name}_mean_rank\"] = preds.mean() + 1\n        metrics[f\"{name}_median_rank\"] = np.floor(np.median(preds)) + 1\n        for k in [1, 5, 10]:\n            metrics[f\"{name}_R@{k}\"] = np.mean(preds < k)\n\n    return metrics\n\n\ndef maybe_compute_generative_loss(model_out):\n    if \"logits\" in model_out and \"labels\" in model_out:\n        token_logits = model_out[\"logits\"]\n        token_labels = model_out[\"labels\"]\n        return F.cross_entropy(token_logits.permute(0, 2, 1), token_labels)\n\n\ndef get_memory():\n    pynvml.nvmlInit()\n    # NOTE: 0 denotes GPU index.\n    handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)\n\n    return meminfo.used / 1024**3\n\n\ndef seconds_to_hms(seconds):\n    hours, remainder = divmod(seconds, 3600)\n    minutes, seconds = divmod(remainder, 60)\n    hours = int(hours); minutes = int(minutes); seconds = int(seconds)\n    return f\"{hours}:{minutes:02d}:{seconds:02d}\"\n\n\ndef cal_grad_norm(model):\n    total_norm = 0\n    for p in model.parameters():\n        if p.grad is not None:\n            param_norm = p.grad.data.norm(2)\n            total_norm += param_norm.item() ** 2\n    total_norm = total_norm ** 0.5\n    return total_norm\n\n\ndef assign_learning_rate(optimizer, new_lr):\n    for param_group in optimizer.param_groups:\n        param_group[\"lr\"] = new_lr\n\n\ndef _warmup_lr(base_lr, warmup_length, step):\n    return base_lr * (step + 1) / warmup_length\n\n\ndef const_lr(optimizer, base_lr, warmup_length, steps):\n    def _lr_adjuster(step):\n        if step < warmup_length:\n            lr = _warmup_lr(base_lr, warmup_length, step)\n        else:\n            lr = base_lr\n        assign_learning_rate(optimizer, lr)\n        return lr\n    return _lr_adjuster\n\n\ndef const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):\n    def _lr_adjuster(step):\n        start_cooldown_step 
= steps - cooldown_steps\n        if step < warmup_length:\n            lr = _warmup_lr(base_lr, warmup_length, step)\n        else:\n            if step < start_cooldown_step:\n                lr = base_lr\n            else:\n                e = step - start_cooldown_step\n                es = steps - start_cooldown_step\n                # linear decay if power == 1; polynomial decay otherwise;\n                decay = (1 - (e/es)) ** cooldown_power\n                lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr\n        assign_learning_rate(optimizer, lr)\n        return lr\n    return _lr_adjuster\n\n\ndef cosine_lr(optimizer, base_lr, warmup_length, steps):\n    def _lr_adjuster(step):\n        if step < warmup_length:\n            lr = _warmup_lr(base_lr, warmup_length, step)\n        else:\n            e = step - warmup_length\n            es = steps - warmup_length\n            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr\n        assign_learning_rate(optimizer, lr)\n        return lr\n    return _lr_adjuster\n\n\ndef postprocess_clip_output(model_out):\n    return {\n        \"image_features\": model_out[0],\n        \"text_features\": model_out[1],\n        \"logit_scale\": model_out[2]\n    }\n\n\ndef unwrap_model(model):\n    if hasattr(model, 'module'):\n        return model.module\n    else:\n        return model\n\n\ndef backward(total_loss, scaler):\n    if scaler is not None:\n        scaler.scale(total_loss).backward()\n    else:\n        total_loss.backward()\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\nclass GradientAccum:\n\n    def __init__(self, model, loss, scaler, autocast, input_dtype, device):\n        self.model = model\n        self.loss = loss\n        self.scaler = scaler\n        self.autocast = autocast\n        self.input_dtype = input_dtype\n        self.device = device\n\n        self.logit_scale = unwrap_model(model).logit_scale\n        self.arch_type = unwrap_model(model).arch_type\n\n        self.accum_freq = 0\n        self.accum_cpu_states = []\n        self.accum_gpu_devices_states = []\n        self.accum_images = []\n        self.accum_texts = []\n        self.accum_image_features = []\n        self.accum_text_features = []\n\n        self.rank = dist.get_rank()\n\n    def clear(self):\n        self.accum_image_features.clear()\n        self.accum_text_features.clear()\n        torch.cuda.empty_cache()\n\n    def clear_state(self):\n        self.accum_images.clear()\n        self.accum_texts.clear()\n        self.accum_cpu_states.clear()\n        self.accum_gpu_devices_states.clear()\n        self.accum_freq = 0\n\n    @torch.no_grad()\n    def accum_inference(self, images, texts):\n        images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)\n        texts = texts.to(device=self.device, non_blocking=True)\n        # First, cache the features without any gradient tracking.\n        with self.autocast():\n            # collect rand states\n            self.accum_cpu_states.append(torch.get_rng_state())\n            self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))\n\n            model_out = 
self.model(images, texts)\n\n            self.accum_image_features.append(model_out[\"image_features\"].detach().clone())\n            self.accum_text_features.append(model_out[\"text_features\"].detach().clone())\n\n        if self.arch_type == \"lit\":\n            # lit\n            accum_image = model_out[\"image_trunk_features\"].detach().clone()\n        else:\n            accum_image = images.detach().clone()\n        accum_text = texts.detach().clone()\n        \n        # offloading\n        accum_image = accum_image.cpu()\n        accum_text  = accum_text.cpu()\n\n        self.accum_images.append(accum_image)\n        self.accum_texts.append(accum_text)\n\n        self.accum_freq += 1\n\n    def accum_forward_backward(self):\n        accum_losses = {\"loss\": 0.0}\n        for j in range(self.accum_freq):\n            # move the offloaded tensors back to the compute device for the re-forward pass\n            images = self.accum_images[j].to(device=self.device, non_blocking=True)\n            texts  = self.accum_texts[j].to(device=self.device, non_blocking=True)\n            \n            # refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235\n            # DDP will sync gradients across GPUs, which is no need except the last batch.\n            sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext\n\n            with torch.random.fork_rng(devices=(self.device,)), sync_context():\n                # setting random states\n                torch.set_rng_state(self.accum_cpu_states[j])\n                torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j])\n\n                with self.autocast():\n                    model_out = self.model(images, texts)\n\n                    inputs_no_accum = {}\n                    inputs_no_accum[\"logit_scale\"] = logit_scale = model_out.pop(\"logit_scale\")\n                    if \"logit_bias\" in model_out:\n                        inputs_no_accum[\"logit_bias\"] = model_out.pop(\"logit_bias\")\n\n                    inputs = {}\n                    inputs[\"image_features\"] = torch.cat(self.accum_image_features[:j] + [model_out[\"image_features\"]] + self.accum_image_features[j + 1:])\n                    inputs[\"text_features\"] = torch.cat(self.accum_text_features[:j] + [model_out[\"text_features\"]] + self.accum_text_features[j + 1:])\n\n                    losses = self.loss(**inputs, **inputs_no_accum)\n                    show_loss = losses.pop(\"show_loss\")\n                    total_loss = sum(losses.values())\n                    losses[\"loss\"] = show_loss\n\n                    del inputs\n                    del inputs_no_accum\n\n                backward(total_loss, self.scaler)\n                accum_losses[\"loss\"] += losses[\"loss\"]\n\n        accum_losses[\"loss\"] /= self.accum_freq\n\n        self.clear()\n        self.clear_state()\n\n        return accum_losses\n\n\nclass GradientCache:\n\n    def __init__(self, model, loss, scaler, autocast, input_dtype, device):\n        self.model = model\n        self.loss = loss\n        self.scaler = scaler\n        self.autocast = autocast\n        self.input_dtype = input_dtype\n        self.device = device\n\n        self.logit_scale = unwrap_model(model).logit_scale\n        self.arch_type = unwrap_model(model).arch_type\n\n        self.accum_freq = 0\n        self.accum_cpu_states = []\n        self.accum_gpu_devices_states = []\n        self.accum_images = []\n        self.accum_texts = []\n        self.accum_image_features = []\n        self.accum_text_features = []\n\n        self.rank = dist.get_rank()\n\n    def 
clear(self):\n        self.accum_image_features.clear()\n        self.accum_text_features.clear()\n        torch.cuda.empty_cache()\n\n    def clear_state(self):\n        self.accum_images.clear()\n        self.accum_texts.clear()\n        self.accum_cpu_states.clear()\n        self.accum_gpu_devices_states.clear()\n        self.accum_freq = 0\n\n    def forward_backward(self, images, texts):\n        images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)\n        texts  = texts.to(device=self.device, non_blocking=True)\n        with self.autocast():\n            model_out = self.model(image=images, text=texts)\n\n        model_out.pop(\"image_trunk_features\", None)\n\n        losses = self.loss(**model_out)\n        show_loss = losses.pop(\"show_loss\")\n        total_loss = sum(losses.values())\n        losses[\"loss\"] = show_loss\n\n        backward(total_loss, self.scaler)\n\n        return losses\n\n    @torch.no_grad()\n    def accum_inference(self, images, texts):\n        images = images.to(device=self.device, dtype=self.input_dtype, non_blocking=True)\n        texts = texts.to(device=self.device, non_blocking=True)\n        # First, cache the features without any gradient tracking.\n        with self.autocast():\n            # collect rand states\n            self.accum_cpu_states.append(torch.get_rng_state())\n            self.accum_gpu_devices_states.append(torch_checkpoint.get_device_states(*[images, texts]))\n\n            model_out = self.model(image=images, text=texts)\n\n            self.accum_image_features.append(model_out[\"image_features\"])\n            self.accum_text_features.append(model_out[\"text_features\"])\n\n        # Speed analysis of detach().clone(): https://stackoverflow.com/questions/55266154/pytorch-preferred-way-to-copy-a-tensor\n        if self.arch_type == \"lit\":\n            # lit\n            accum_image = model_out[\"image_trunk_features\"].detach().clone()\n        else:\n            accum_image = images.detach().clone()\n        accum_text = texts.detach().clone()\n        \n        # offloading\n        # accum_image = accum_image.cpu()\n        # accum_text  = accum_text.cpu()\n\n        self.accum_images.append(accum_image)\n        self.accum_texts.append(accum_text)\n\n        self.accum_freq += 1\n\n    def accum_forward_backward(self):\n        accum_qs = [x.requires_grad_() for x in self.accum_image_features]; qs = torch.cat(accum_qs, dim=0)\n        accum_ks = [x.requires_grad_() for x in self.accum_text_features]; ks = torch.cat(accum_ks, dim=0)\n        ls = self.logit_scale.exp().detach().clone().requires_grad_()\n\n        losses = self.loss(image_features=qs, text_features=ks, logit_scale=ls)\n        show_loss = losses.pop(\"show_loss\")\n        total_loss = sum(losses.values())\n        losses[\"loss\"] = show_loss\n\n        backward(total_loss, self.scaler)\n\n        accum_q_grads = [q.grad for q in accum_qs]\n        accum_k_grads = [k.grad for k in accum_ks]\n        l_grad = ls.grad\n\n        del accum_qs, accum_ks\n        del qs, ks, ls\n\n        # Clean trash memory from loss calculation or inference\n        self.clear()\n\n        for j in range(self.accum_freq):\n            images = self.accum_images[j]\n            texts = self.accum_texts[j]\n\n            # refer to the implementation of Gradient Cache: https://github.com/luyug/GradCache/blob/906f03835fbc183132a9db32612a9e8f180ca3b4/src/grad_cache/grad_cache.py#L235\n            # DDP will sync gradients across GPUs, which is no 
need except the last batch.\n            sync_context = self.model.no_sync if j != self.accum_freq - 1 else nullcontext\n\n            with torch.random.fork_rng(devices=(self.device, )), sync_context():\n                # setting random states\n                torch.set_rng_state(self.accum_cpu_states[j])\n                torch_checkpoint.set_device_states(*self.accum_gpu_devices_states[j])\n\n                with self.autocast():\n                    if self.arch_type == \"lit\":\n                        model_out = self.model(images, texts, project_only=True)\n                    else:\n                        model_out = self.model(images, texts)\n\n                q = model_out[\"image_features\"]\n                k = model_out[\"text_features\"]\n                l = model_out[\"logit_scale\"]\n\n                _loss = torch.dot(q.flatten(), accum_q_grads[j].flatten()) + \\\n                        torch.dot(k.flatten(), accum_k_grads[j].flatten()) + \\\n                        l * l_grad / self.accum_freq\n\n                _loss.backward()\n\n        self.clear_state()\n\n        return losses\n\n\ndef train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):\n    device = torch.device(args.device)\n    autocast = get_autocast(args.precision)\n    input_dtype = get_input_dtype(args.precision)\n\n    model.train()\n\n    data['train'].set_epoch(epoch)  # set epoch in process safe manner via sampler or shared_epoch\n    dataloader = data['train'].dataloader\n    num_batches_per_epoch = dataloader.num_batches // args.accum_freq\n    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))\n\n    runner = GradientCache(model, loss, scaler, autocast, input_dtype, device)\n\n    rest_iters = num_batches_per_epoch * (args.epochs - epoch)\n\n    losses_m = {}\n    global_batch_time_m = AverageMeter()\n    batch_time_m = AverageMeter()\n    data_time_m  = AverageMeter()\n    end = time.time()\n    for i, batch in enumerate(dataloader):\n        i_accum = i // args.accum_freq\n        step = num_batches_per_epoch * epoch + i_accum\n\n        if not args.skip_scheduler:\n            scheduler(step)\n\n        images, texts = batch\n\n        data_time_m.update(time.time() - end)\n        optimizer.zero_grad()\n\n        if args.accum_freq == 1:\n            losses = runner.forward_backward(images, texts)\n        else:\n            runner.accum_inference(images, texts)\n\n            # If (i + 1) % accum_freq is not zero, move on to the next batch.\n            if ((i + 1) % args.accum_freq) > 0:\n                # FIXME this makes data time logging unreliable when accumulating\n                continue\n\n            # Now, ready to take gradients for the last accum_freq batches.\n            # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.\n            # Call backwards each time, but only step optimizer at the end.\n            losses = runner.accum_forward_backward()\n\n        if scaler is not None:\n            if args.horovod:\n                optimizer.synchronize()\n                scaler.unscale_(optimizer)\n                if args.grad_clip_norm is not None:\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)\n                with optimizer.skip_synchronize():\n                    scaler.step(optimizer)\n            else:\n                if args.grad_clip_norm is not None:\n                   
 scaler.unscale_(optimizer)\n                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)\n                scaler.step(optimizer)\n            scaler.update()\n        else:\n            if args.grad_clip_norm is not None:\n                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)\n            optimizer.step()\n\n        # Note: we clamp to 4.6052 = ln(100), as in the original paper.\n        with torch.no_grad():\n            unwrap_model(model).logit_scale.clamp_(0, math.log(100))\n\n        global_batch_time_m.update(time.time() - end)\n        batch_time_m.update(time.time() - end)\n        end = time.time()\n        batch_count = i_accum + 1\n        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):\n            batch_size = len(images)\n            num_samples = batch_count * batch_size * args.accum_freq * args.world_size\n            samples_per_epoch = dataloader.num_samples\n            percent_complete = batch_count / num_batches_per_epoch\n\n            # NOTE loss is coarsely sampled, just master node and per log update\n            for key, val in losses.items():\n                if key not in losses_m:\n                    losses_m[key] = AverageMeter()\n                losses_m[key].update(val.item(), batch_size)\n\n            logit_scale_scalar = unwrap_model(model).logit_scale.exp().item()\n            loss_log = \" \".join(\n                [\n                    f\"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})\" \n                    for loss_name, loss_m in losses_m.items()\n                ]\n            )\n            samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val\n            samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val\n            grad_norm = cal_grad_norm(model.module)\n\n            running_time = seconds_to_hms(time.time() - start_timestamp)\n            rest_iters = rest_iters - 1\n            whole_time   = seconds_to_hms(time.time() - start_timestamp + rest_iters * global_batch_time_m.avg)\n            logging.info(\n                f\"{running_time}<{whole_time} \"\n                f\"Epoch: {epoch + percent_complete:.2f} \"\n                f\"Data (t): {data_time_m.avg:.3f} \"\n                f\"Batch (t): {batch_time_m.avg:.3f} \"\n                f\"LR: {optimizer.param_groups[0]['lr']:5f} \"\n                f\"Grad Norm: {grad_norm:.3f} \"\n                f\"Logit Scale: {logit_scale_scalar:.3f} \" + loss_log + \" \"\n                f\"Memory: {get_memory():.2f}GB \"\n            )\n\n            # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing\n            log_data = {\n                \"data_time\": data_time_m.val,\n                \"batch_time\": batch_time_m.val,\n                \"samples_per_second\": samples_per_second,\n                \"samples_per_second_per_gpu\": samples_per_second_per_gpu,\n                \"scale\": logit_scale_scalar,\n                \"grad_norm\": grad_norm,\n                \"lr\": optimizer.param_groups[0][\"lr\"]\n            }\n            log_data.update({name:val.val for name,val in losses_m.items()})\n\n            log_data = {\"train/\" + name: val for name, val in log_data.items()}\n\n            if tb_writer is not None:\n                for name, val in log_data.items():\n                    tb_writer.add_scalar(name, val, step)\n\n            if args.wandb:\n                assert wandb is not None, 'Please install wandb.'\n                log_data['step'] = step  # for backwards compatibility\n                wandb.log(log_data, step=step)\n\n            # resetting batch / data time meters per log window\n            batch_time_m.reset()\n            data_time_m.reset()\n    # end for\n\n\ndef evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):\n    metrics = {}\n    if not is_master(args):\n        return metrics\n    device = torch.device(args.device)\n    model.eval()\n\n    zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer)\n    metrics.update(zero_shot_metrics)\n\n    autocast = get_autocast(args.precision)\n    input_dtype = get_input_dtype(args.precision)\n\n    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):\n        dataloader = data['val'].dataloader\n        num_samples = 0\n        samples_per_val = dataloader.num_samples\n\n        # FIXME this does not scale past small eval datasets\n        # all_image_features @ all_text_features will blow up memory and compute very quickly\n        cumulative_loss = 0.0\n        cumulative_gen_loss = 0.0\n        all_image_features, all_text_features = [], []\n        with torch.inference_mode():\n            for i, batch in enumerate(dataloader):\n                images, texts = batch\n                images = images.to(device=device, dtype=input_dtype, non_blocking=True)\n                texts = texts.to(device=device, non_blocking=True)\n\n                with autocast():\n                    model_out = model(images, texts)\n                    image_features = model_out[\"image_features\"]\n                    text_features = model_out[\"text_features\"]\n                    logit_scale = model_out[\"logit_scale\"]\n                    # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly\n                    # however, system RAM is easily exceeded and compute time becomes problematic\n                    all_image_features.append(image_features.cpu())\n                    all_text_features.append(text_features.cpu())\n                    logit_scale = logit_scale.mean()\n                    logits_per_image = logit_scale * image_features @ text_features.t()\n                    logits_per_text = logits_per_image.t()\n\n                    batch_size = images.shape[0]\n                    labels = torch.arange(batch_size, device=device).long()\n                    total_loss = (\n                        F.cross_entropy(logits_per_image, labels) +\n                        F.cross_entropy(logits_per_text, labels)\n                    
) / 2\n\n                    gen_loss = maybe_compute_generative_loss(model_out)\n\n                cumulative_loss += total_loss * batch_size\n                num_samples += batch_size\n                if is_master(args) and (i % 100) == 0:\n                    logging.info(\n                        f\"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\\t\"\n                        f\"Clip Loss: {cumulative_loss / num_samples:.6f}\\t\")\n\n                    if gen_loss is not None:\n                        cumulative_gen_loss += gen_loss * batch_size\n                        logging.info(\n                            f\"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\\t\")\n\n            val_metrics = get_clip_metrics(\n                image_features=torch.cat(all_image_features),\n                text_features=torch.cat(all_text_features),\n                logit_scale=logit_scale.cpu(),\n            )\n            loss = cumulative_loss / num_samples\n            metrics.update(\n                {**val_metrics, \"clip_val_loss\": loss.item(), \"epoch\": epoch, \"num_samples\": num_samples}\n            )\n            if gen_loss is not None:\n                gen_loss = cumulative_gen_loss / num_samples\n                metrics.update({\"val_generative_loss\": gen_loss.item()})\n\n    if not metrics:\n        return metrics\n\n    logging.info(\n        f\"Eval Epoch: {epoch} \"\n        + \"\\t\".join([f\"{k}: {round(v, 4):.4f}\" for k, v in metrics.items()])\n    )\n\n    log_data = {\"val/\" + name: val for name, val in metrics.items()}\n\n    if args.save_logs:\n        if tb_writer is not None:\n            for name, val in log_data.items():\n                tb_writer.add_scalar(name, val, epoch)\n\n        with open(os.path.join(args.checkpoint_path, \"results.jsonl\"), \"a+\") as f:\n            f.write(json.dumps(metrics))\n            f.write(\"\\n\")\n\n    if args.wandb:\n        assert wandb is not None, 'Please install wandb.'\n        if 'train' in data:\n            dataloader = data['train'].dataloader\n            num_batches_per_epoch = dataloader.num_batches // args.accum_freq\n            step = num_batches_per_epoch * epoch\n        else:\n            step = None\n        log_data['epoch'] = epoch\n        wandb.log(log_data, step=step)\n\n    return metrics\n\n\ndef zero_shot_run(model, classifier, dataloader, args):\n    autocast = get_autocast(args.precision)\n    input_dtype = get_input_dtype(args.precision)\n\n    with torch.inference_mode():\n        top1, top5, n = 0., 0., 0.\n        for images, target in tqdm(dataloader, unit_scale=args.batch_size):\n            images = images.to(device=args.device, dtype=input_dtype)\n            target = target.to(args.device)\n\n            with autocast():\n                # predict\n                output = model(image=images)\n                image_features = output['image_features'] if isinstance(output, dict) else output[0]\n                logits = 100. 
* image_features @ classifier\n\n            # measure accuracy\n            acc1, acc5 = accuracy(logits, target, topk=(1, 5))\n            top1 += acc1\n            top5 += acc5\n            n += images.size(0)\n\n    top1 = (top1 / n)\n    top5 = (top5 / n)\n    return top1, top5\n\n\ndef zero_shot_eval(model, data, epoch, args, tokenizer=None):\n    if 'imagenet-val' not in data and 'imagenet-v2' not in data:\n        return {}\n    if args.zeroshot_frequency == 0:\n        return {}\n    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:\n        return {}\n    if args.distributed and not args.horovod:\n        model = model.module\n\n    logging.info('Starting zero-shot imagenet.')\n    if tokenizer is None:\n        tokenizer = get_tokenizer(args.model)\n\n    logging.info('Building zero-shot classifier')\n    autocast = get_autocast(args.precision)\n    with autocast():\n        classifier = build_zero_shot_classifier(\n            model,\n            tokenizer=tokenizer,\n            classnames=IMAGENET_CLASSNAMES,\n            templates=OPENAI_IMAGENET_TEMPLATES,\n            num_classes_per_batch=10,\n            device=args.device,\n            use_tqdm=True,\n        )\n\n    logging.info('Using classifier')\n    results = {}\n    if 'imagenet-val' in data:\n        top1, top5 = zero_shot_run(model, classifier, data['imagenet-val'].dataloader, args)\n        results['imagenet-zeroshot-val-top1'] = top1\n        results['imagenet-zeroshot-val-top5'] = top5\n    if 'imagenet-v2' in data:\n        top1, top5 = zero_shot_run(model, classifier, data['imagenet-v2'].dataloader, args)\n        results['imagenetv2-zeroshot-val-top1'] = top1\n        results['imagenetv2-zeroshot-val-top5'] = top5\n\n    logging.info('Finished zero-shot imagenet.')\n\n    return results\n"
  },
  {
    "path": "inf_clip/train/main.py",
    "content": "import glob\nimport logging\nimport os\nimport re\nimport subprocess\nimport sys\nimport random\nimport time\nfrom functools import partial\n\nimport numpy as np\nimport torch\nfrom torch import optim\nfrom torch.cuda.amp import GradScaler\n\ntry:\n    import wandb\nexcept ImportError:\n    wandb = None\n\ntry:\n    import torch.utils.tensorboard as tensorboard\nexcept ImportError:\n    tensorboard = None\n\ntry:\n    import horovod.torch as hvd\nexcept ImportError:\n    hvd = None\n\nfrom inf_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss\nfrom inf_clip.train.data import get_data\nfrom inf_clip.train.params import parse_args\nfrom inf_clip.train.optims import ScalingViTAdafactor, Lion\nfrom inf_clip.train.engine import cosine_lr, const_lr, const_lr_cooldown, train_one_epoch, evaluate\nfrom inf_clip.train.utils import (setup_logging, pt_load, check_exists, start_sync_process, remote_sync, is_master, init_distributed_device, broadcast_object)\n\n\nLATEST_CHECKPOINT_NAME = \"epoch_latest.pt\"\n\n\ndef random_seed(seed=42, rank=0):\n    random.seed(seed + rank)\n    np.random.seed(seed + rank)\n    torch.manual_seed(seed + rank)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef natural_key(string_):\n    \"\"\"See http://www.codinghorror.com/blog/archives/001018.html\"\"\"\n    return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\n\n\ndef copy_codebase(args):\n    from shutil import copytree, ignore_patterns\n    new_code_path = os.path.join(args.log_dir, args.name, \"code\")\n    if os.path.exists(new_code_path):\n        print(\n            f\"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment.\"\n        )\n        return -1\n    print(f\"Copying codebase to {new_code_path}\")\n    current_code_path = os.path.realpath(__file__)\n    for _ in range(3):\n        current_code_path = os.path.dirname(current_code_path)\n    copytree(current_code_path, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))\n    print(\"Done copying code.\")\n    return 1\n\n\ndef prepare_logging(args):\n    # get the name of the experiments\n    if args.name is None:\n        from datetime import datetime\n        # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?\n        model_name_safe = args.model.replace('/', '-')\n        date_str = datetime.now().strftime(\"%Y_%m_%d-%H_%M_%S\")\n        if args.distributed:\n            # sync date_str from master to all ranks\n            date_str = broadcast_object(args, date_str)\n        args.name = '-'.join([\n            date_str,\n            f\"model_{model_name_safe}\",\n            f\"lr_{args.lr}\",\n            f\"b_{args.batch_size}\",\n            f\"j_{args.workers}\",\n            f\"p_{args.precision}\",\n        ])\n\n    resume_latest = args.resume == 'latest'\n    log_base_path = os.path.join(args.log_dir, args.name)\n\n    args.log_path = None\n    if is_master(args, local=args.log_local):\n        os.makedirs(log_base_path, exist_ok=True)\n        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'\n        args.log_path = os.path.join(log_base_path, log_filename)\n        # if os.path.exists(args.log_path) and not resume_latest:\n        #     print(\n        #         \"Error. Experiment already exists. 
Use --name {} to specify a new experiment.\"\n        #     )\n        #     return -1\n\n    # Setup text logger\n    args.log_level = logging.DEBUG if args.debug else logging.INFO\n    setup_logging(args.log_path, args.log_level)\n\n    # Setup wandb, tensorboard, checkpoint logging\n    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to\n    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to\n    args.checkpoint_path = os.path.join(log_base_path, \"checkpoints\")\n    if is_master(args):\n        args.tensorboard_path = os.path.join(log_base_path, \"tensorboard\") if args.tensorboard else ''\n        for dirname in [args.tensorboard_path, args.checkpoint_path]:\n            if dirname:\n                os.makedirs(dirname, exist_ok=True)\n    else:\n        args.tensorboard_path = ''\n\n    if args.copy_codebase:\n        copy_codebase(args)\n\n    return log_base_path, resume_latest\n\n\ndef get_latest_checkpoint(path: str, remote : bool):\n    # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders\n    if remote:\n        result = subprocess.run([\"aws\", \"s3\", \"ls\", path + \"/\"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n        print(result)\n        if result.returncode == 1:\n            return None\n        checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\\n')[:-1]]\n    else:\n        checkpoints = glob.glob(path + '**/*.pt', recursive=True)\n    if checkpoints:\n        checkpoints = sorted(checkpoints, key=natural_key)\n        return checkpoints[-1]\n    return None\n\n\ndef prepare_resuming(args):\n    resume_from = None\n    checkpoint_path = args.checkpoint_path\n    # If using remote_sync, need to check the remote instead of the local checkpoints folder.\n    if args.remote_sync is not None:\n        checkpoint_path = os.path.join(args.remote_sync, args.name, \"checkpoints\")\n        if args.save_most_recent:\n            print('Error. Cannot use save-most-recent with remote_sync and resume latest.')\n            return -1\n        if args.remote_sync_protocol != 's3':\n            print('Error. Sync protocol not supported when using resume latest.')\n            return -1\n    if is_master(args):\n        # Checking for existing checkpoint via master rank only. 
It is possible for\n        # different rank processes to see different files if a shared file-system is under\n        # stress, however it's very difficult to fully work around such situations.\n        if args.save_most_recent:\n            # if --save-most-recent flag is set, look for latest at a fixed filename\n            resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)\n            if not os.path.exists(resume_from):\n                # If no latest checkpoint has been saved yet, don't try to resume\n                resume_from = None\n        else:\n            # otherwise, list checkpoint dir contents and pick the newest checkpoint\n            resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)\n        if resume_from:\n            logging.info(f'Found latest resume checkpoint at {resume_from}.')\n        else:\n            logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')\n    if args.distributed:\n        # sync found checkpoint path to all ranks\n        resume_from = broadcast_object(args, resume_from)\n    args.resume = resume_from\n\n\ndef prepare_remote_sync(args):\n    # start the sync proces if remote-sync is not None\n    remote_sync_process = None\n    if is_master(args) and args.remote_sync is not None:\n        # first make sure it works\n        result = remote_sync(\n            os.path.join(args.log_dir, args.name), \n            os.path.join(args.remote_sync, args.name), \n            args.remote_sync_protocol\n        )\n        if result:\n            logging.info('remote sync successful.')\n        else:\n            logging.info('Error: remote sync failed. Exiting.')\n            return -1\n        # if all looks good, start a process to do this every args.remote_sync_frequency seconds\n        remote_sync_process = start_sync_process(\n            args.remote_sync_frequency,\n            os.path.join(args.log_dir, args.name), \n            os.path.join(args.remote_sync, args.name), \n            args.remote_sync_protocol\n        )\n        remote_sync_process.start()\n\n    return remote_sync_process\n\n\ndef prepare_model(args, device):\n    dist_model = None\n    args.distill = args.distill_model is not None and args.distill_pretrained is not None\n    if args.distill:\n        #FIXME: support distillation with grad accum.\n        assert args.accum_freq == 1\n        #FIXME: support distillation with coca.\n        assert 'coca' not in args.model.lower()\n\n    if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:\n        # arg is nargs, single (square) image size list -> int\n        args.force_image_size = args.force_image_size[0]\n    random_seed(args.seed, 0)\n    model_kwargs = {}\n    if args.siglip:\n        model_kwargs['init_logit_scale'] = np.log(10)  # different from CLIP\n        model_kwargs['init_logit_bias'] = -10\n    model, preprocess_train, preprocess_val = create_model_and_transforms(\n        args.model,\n        args.pretrained,\n        precision=args.precision,\n        device=device,\n        jit=args.torchscript,\n        force_quick_gelu=args.force_quick_gelu,\n        force_custom_text=args.force_custom_text,\n        force_patch_dropout=args.force_patch_dropout,\n        force_image_size=args.force_image_size,\n        image_mean=args.image_mean,\n        image_std=args.image_std,\n        image_interpolation=args.image_interpolation,\n        image_resize_mode=args.image_resize_mode,  # only effective for 
inference\n        aug_cfg=args.aug_cfg,\n        pretrained_image=args.pretrained_image,\n        output_dict=True,\n        **model_kwargs,\n    )\n    if args.distill:\n        # FIXME: currently assumes the model you're distilling from has the same tokenizer & transforms.\n        dist_model, _, _ = create_model_and_transforms(\n            args.distill_model, \n            args.distill_pretrained,\n            device=device,\n            precision=args.precision,\n            output_dict=True,\n        )\n    if args.use_bnb_linear is not None:\n        print('=> using a layer from bitsandbytes.\\n'\n              '   this is an experimental feature which requires two extra pip installs\\n'\n              '   pip install bitsandbytes triton'\n              '   please make sure to use triton 2.0.0')\n        import bitsandbytes as bnb\n        from open_clip.utils import replace_linear\n        print(f'=> replacing linear layers with {args.use_bnb_linear}')\n        linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear)\n        replace_linear(model, linear_replacement_cls)\n        model = model.to(device)\n\n    random_seed(args.seed, args.rank)\n\n    if args.trace:\n        model = trace_model(model, batch_size=args.batch_size, device=device)\n\n    if args.lock_image:\n        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n        model.lock_image_tower(\n            unlocked_groups=args.lock_image_unlocked_groups,\n            freeze_bn_stats=args.lock_image_freeze_bn_stats)\n    if args.lock_text:\n        model.lock_text_tower(\n            unlocked_layers=args.lock_text_unlocked_layers,\n            freeze_layer_norm=args.lock_text_freeze_layer_norm)\n\n    if args.grad_checkpointing:\n        model.set_grad_checkpointing()\n\n    if args.distributed and not args.horovod:\n        if args.use_bn_sync:\n            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n        ddp_args = {}\n        if args.ddp_static_graph:\n            # this doesn't exist in older PyTorch, arg only added if enabled\n            ddp_args['static_graph'] = True\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)\n\n        if args.distill:\n            dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args)\n\n    if is_master(args):\n        logging.info(\"Model:\")\n        logging.info(f\"{str(model)}\")\n        logging.info(\"Params:\")\n        params_file = os.path.join(args.log_dir, args.name, \"params.txt\")\n        with open(params_file, \"w\") as f:\n            for name in sorted(vars(args)):\n                val = getattr(args, name)\n                logging.info(f\"  {name}: {val}\")\n                f.write(f\"{name}: {val}\\n\")\n\n    tokenizer = get_tokenizer(args.model)\n\n    return tokenizer, model, dist_model, preprocess_train, preprocess_val\n\n\ndef prepare_optimizer_scaler(args, model):\n    assert not args.trace, 'Cannot train with traced model'\n\n    exclude = lambda n, p: p.ndim < 2 or \"bn\" in n or \"ln\" in n or \"bias\" in n or 'logit_scale' in n\n    include = lambda n, p: not exclude(n, p)\n\n    named_parameters = list(model.named_parameters())\n    named_parameters = [(n, p) for n, p in named_parameters if p.requires_grad]\n    gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]\n    rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]\n\n    
if args.optimizer == \"adam\":\n        optimizer = optim.AdamW(\n            [\n                {\"params\": gain_or_bias_params, \"weight_decay\": 0.},\n                {\"params\": rest_params, \"weight_decay\": args.wd},\n            ],\n            lr=args.lr,\n            betas=(args.beta1, args.beta2),\n            eps=args.eps,\n        )\n    elif args.optimizer == \"adafactor\":\n        optimizer = ScalingViTAdafactor(\n            [\n                {\"params\": gain_or_bias_params, \"weight_decay\": 0.},\n                {\"params\": rest_params, \"weight_decay\": args.wd},\n            ],\n            lr=args.lr,\n            beta1=args.beta1,\n            beta2=args.beta2,\n        )\n    elif args.optimizer == \"lion\":\n        optimizer = Lion(\n            [\n                {\"params\": gain_or_bias_params, \"weight_decay\": 0.},\n                {\"params\": rest_params, \"weight_decay\": args.wd},\n            ],\n            lr=args.lr,\n            betas=(args.beta1, args.beta2),\n        )\n    else:\n        logging.error(f'Unknown optimizer, {args.optimizer}. Available options are: adam, adafactor, lion.')\n        exit(1)\n\n    # elif args.optim == \"lamb\":\n    if args.horovod:\n        optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())\n        hvd.broadcast_parameters(model.state_dict(), root_rank=0)\n        hvd.broadcast_optimizer_state(optimizer, root_rank=0)\n\n    scaler = GradScaler() if args.precision == \"amp\" else None\n\n    return optimizer, scaler\n\n\ndef prepare_scheduler(args, optimizer, num_batches):\n    scheduler = None\n\n    total_steps = (num_batches // args.accum_freq) * args.epochs\n    if args.lr_scheduler == \"cosine\":\n        scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)\n    elif args.lr_scheduler == \"const\":\n        scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)\n    elif args.lr_scheduler == \"const-cooldown\":\n        assert args.epochs_cooldown is not None,\\\n            \"Please specify the number of cooldown epochs for this lr schedule.\"\n        cooldown_steps = (num_batches // args.accum_freq) * args.epochs_cooldown\n        scheduler = const_lr_cooldown(\n            optimizer, args.lr, args.warmup, total_steps,\n            cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)\n    else:\n        logging.error(\n            f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')\n        exit(1)\n\n    return scheduler\n\n\ndef main(args):\n    args = parse_args(args)\n\n    if torch.cuda.is_available():\n        # This enables tf32 on Ampere GPUs which is only 8% slower than\n        # float16 and almost as accurate as float32\n        # This was a default in pytorch until 1.12\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.benchmark = True\n        torch.backends.cudnn.deterministic = False\n\n    # fully initialize distributed device environment\n    device = init_distributed_device(args)\n\n    log_base_path, resume_latest = prepare_logging(args)\n    if resume_latest and prepare_resuming(args) == -1:\n        return -1\n    remote_sync_process = prepare_remote_sync(args)\n    if remote_sync_process == -1:\n        return -1\n\n    if args.horovod:\n        logging.info(\n            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'\n            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')\n    elif args.distributed:\n        logging.info(\n            f'Running in distributed mode with multiple processes. 
Device: {args.device}.'\n            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')\n    else:\n        logging.info(f'Running with a single process. Device {args.device}.')\n\n    tokenizer, model, dist_model, preprocess_train, preprocess_val = prepare_model(args, device)\n\n    # create optimizer and scaler\n    optimizer = None\n    scaler = None\n    if args.train_data or args.dataset_type == \"synthetic\":\n        optimizer, scaler = prepare_optimizer_scaler(args, model)\n\n    # optionally resume from a checkpoint\n    start_epoch = 0\n    if args.resume is not None:\n        checkpoint = pt_load(args.resume, map_location='cpu')\n        if 'epoch' in checkpoint:\n            # resuming a train checkpoint w/ epoch and optimizer state\n            start_epoch = checkpoint[\"epoch\"]\n            sd = checkpoint[\"state_dict\"]\n            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):\n                sd = {k[len('module.'):]: v for k, v in sd.items()}\n            model.load_state_dict(sd)\n            if optimizer is not None:\n                optimizer.load_state_dict(checkpoint[\"optimizer\"])\n            if scaler is not None and 'scaler' in checkpoint:\n                scaler.load_state_dict(checkpoint['scaler'])\n            logging.info(f\"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})\")\n        else:\n            # loading a bare (model only) checkpoint for fine-tune or evaluation\n            model.load_state_dict(checkpoint)\n            logging.info(f\"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})\")\n\n    # initialize datasets\n    data = get_data(\n        args,\n        (preprocess_train, preprocess_val),\n        epoch=start_epoch,\n        tokenizer=tokenizer,\n    )\n    assert len(data), 'At least one train or eval dataset must be specified.'\n\n    # create scheduler if train\n    scheduler = None\n    if 'train' in data and optimizer is not None:\n        scheduler = prepare_scheduler(args, optimizer, data[\"train\"].dataloader.num_batches)\n\n    # determine if this worker should save logs and checkpoints. only do so if it is rank == 0\n    args.save_logs = args.log_dir and args.log_dir.lower() != 'none' and is_master(args)\n    writer = None\n    if args.save_logs and args.tensorboard:\n        assert tensorboard is not None, \"Please install tensorboard.\"\n        writer = tensorboard.SummaryWriter(args.tensorboard_path)\n\n    if args.wandb and is_master(args):\n        assert wandb is not None, 'Please install wandb.'\n        logging.debug('Starting wandb.')\n        args.train_sz = data[\"train\"].dataloader.num_samples\n        if args.val_data is not None:\n            args.val_sz = data[\"val\"].dataloader.num_samples\n        # you will have to configure this for your project!\n        wandb.init(\n            project=args.wandb_project_name,\n            name=args.name,\n            id=args.name,\n            notes=args.wandb_notes,\n            tags=[],\n            resume='auto' if args.resume == \"latest\" else None,\n            config=vars(args),\n        )\n        if args.debug:\n            wandb.watch(model, log='all')\n        # params.txt is written by prepare_model() on the master rank\n        params_file = os.path.join(args.log_dir, args.name, \"params.txt\")\n        wandb.save(params_file)\n        logging.debug('Finished loading wandb.')\n\n    # Pytorch 2.0 adds '_orig_mod.' 
prefix to keys of state_dict() of compiled models.\n    # For compatibility, we save state_dict() of the original model, which shares the\n    # weights without the prefix.\n    original_model = model\n    if args.torchcompile:\n        logging.info('Compiling model...')\n        model = torch.compile(original_model)\n\n    if 'train' not in data:\n        # If using int8, convert to inference mode.\n        if args.use_bnb_linear is not None:\n            from open_clip.utils import convert_int8_model_to_inference_mode\n            convert_int8_model_to_inference_mode(model)\n        # Evaluate.\n        evaluate(model, data, start_epoch, args, tb_writer=writer, tokenizer=tokenizer)\n        return\n\n    loss = create_loss(args)\n    start_timestamp = time.time()\n    for epoch in range(start_epoch, args.epochs):\n        if is_master(args):\n            logging.info(f'Start epoch {epoch}')\n\n        train_one_epoch(start_timestamp, model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer)\n        completed_epoch = epoch + 1\n\n        if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):\n            evaluate(model, data, completed_epoch, args, tb_writer=writer, tokenizer=tokenizer)\n\n        # Saving checkpoints.\n        if args.save_logs:\n            checkpoint_dict = {\n                \"epoch\": completed_epoch,\n                \"name\": args.name,\n                \"state_dict\": original_model.state_dict(),\n                \"optimizer\": optimizer.state_dict(),\n            }\n            if scaler is not None:\n                checkpoint_dict[\"scaler\"] = scaler.state_dict()\n\n            if completed_epoch == args.epochs or (\n                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0\n            ):\n                torch.save(\n                    checkpoint_dict,\n                    os.path.join(args.checkpoint_path, f\"epoch_{completed_epoch}.pt\"),\n                )\n            if args.delete_previous_checkpoint:\n                previous_checkpoint = os.path.join(args.checkpoint_path, f\"epoch_{completed_epoch - 1}.pt\")\n                if os.path.exists(previous_checkpoint):\n                    os.remove(previous_checkpoint)\n\n            if args.save_most_recent:\n                # try not to corrupt the latest checkpoint if save fails\n                tmp_save_path = os.path.join(args.checkpoint_path, \"tmp.pt\")\n                latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)\n                torch.save(checkpoint_dict, tmp_save_path)\n                os.replace(tmp_save_path, latest_save_path)\n\n    if args.wandb and is_master(args):\n        wandb.finish()\n\n    # run a final sync.\n    if remote_sync_process is not None:\n        logging.info('Final remote sync.')\n        remote_sync_process.terminate()\n        result = remote_sync(\n            os.path.join(args.log_dir, args.name), \n            os.path.join(args.remote_sync, args.name), \n            args.remote_sync_protocol\n        )\n        if result:\n            logging.info('Final remote sync successful.')\n        else:\n            logging.info('Final remote sync failed.')\n\n\nif __name__ == \"__main__\":\n    main(sys.argv[1:])\n"
  },
  {
    "path": "inf_clip/train/optims.py",
    "content": "import math\n\nimport torch\nfrom torch import nn\nfrom torch.optim import Optimizer\n\n\nclass ScalingViTAdafactor(Optimizer):\n    \"\"\"\n    Modified version of Adafactor in Transformers https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/optimization.py#L672, \n    which refers to Paper: *Scaling Vision Transformers* https://arxiv.org/pdf/2106.04560\n\n    1. Re-introducing the first momentum in half-precision.\n    2. Disable scaling of learning rate relative to weight norms, a feature that is part of Adafactor.\n    3. Clipping the second momentum at 0.999\n    \"\"\"\n\n    def __init__(\n        self,\n        params,\n        lr=None,\n        eps=(1e-30, 1e-3),\n        clip_threshold=1.0,\n        decay_rate=-0.8,\n        beta1=0.9,\n        beta2=0.999,\n        weight_decay=0.0,\n        scale_parameter=False,\n        relative_step=False,\n        warmup_init=False,\n    ):\n        if lr is not None and relative_step:\n            raise ValueError(\"Cannot combine manual `lr` and `relative_step=True` options\")\n        if warmup_init and not relative_step:\n            raise ValueError(\"`warmup_init=True` requires `relative_step=True`\")\n\n        defaults = {\n            \"lr\": lr,\n            \"eps\": eps,\n            \"clip_threshold\": clip_threshold,\n            \"decay_rate\": decay_rate,\n            \"beta1\": beta1,\n            \"beta2\": beta2,\n            \"weight_decay\": weight_decay,\n            \"scale_parameter\": scale_parameter,\n            \"relative_step\": relative_step,\n            \"warmup_init\": warmup_init,\n        }\n        super().__init__(params, defaults)\n\n    @staticmethod\n    def _get_lr(param_group, param_state):\n        rel_step_sz = param_group[\"lr\"]\n        if param_group[\"relative_step\"]:\n            min_step = 1e-6 * param_state[\"step\"] if param_group[\"warmup_init\"] else 1e-2\n            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state[\"step\"]))\n        param_scale = 1.0\n        if param_group[\"scale_parameter\"]:\n            param_scale = max(param_group[\"eps\"][1], param_state[\"RMS\"])\n        return param_scale * rel_step_sz\n\n    @staticmethod\n    def _get_options(param_group, param_shape):\n        factored = len(param_shape) >= 2\n        use_first_moment = param_group[\"beta1\"] is not None\n        return factored, use_first_moment\n\n    @staticmethod\n    def _rms(tensor):\n        return tensor.norm(2) / (tensor.numel() ** 0.5)\n\n    @staticmethod\n    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):\n        # copy from fairseq's adafactor implementation:\n        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505\n        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)\n        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()\n        return torch.mul(r_factor, c_factor)\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        \"\"\"\n        Performs a single optimization step\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group[\"params\"]:\n                if p.grad is None:\n                    continue\n                grad = p.grad\n        
        # NOTE: gradient keep in float32\n                # if grad.dtype in {torch.float16, torch.bfloat16}:\n                #     grad = grad.float()\n                if grad.is_sparse:\n                    raise RuntimeError(\"Adafactor does not support sparse gradients.\")\n\n                state = self.state[p]\n                grad_shape = grad.shape\n\n                factored, use_first_moment = self._get_options(group, grad_shape)\n                # State Initialization\n                if len(state) == 0:\n                    state[\"step\"] = 0\n\n                    if use_first_moment:\n                        # NOTE: using bfloat16 for first momentum\n                        # Exponential moving average of gradient values\n                        state[\"exp_avg\"] = torch.zeros_like(grad).bfloat16()\n                    if factored:\n                        state[\"exp_avg_sq_row\"] = torch.zeros(grad_shape[:-1]).to(grad)\n                        state[\"exp_avg_sq_col\"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)\n                    else:\n                        state[\"exp_avg_sq\"] = torch.zeros_like(grad)\n\n                    state[\"RMS\"] = 0\n                else:\n                    if use_first_moment:\n                        state[\"exp_avg\"] = state[\"exp_avg\"].to(grad)\n                    if factored:\n                        state[\"exp_avg_sq_row\"] = state[\"exp_avg_sq_row\"].to(grad)\n                        state[\"exp_avg_sq_col\"] = state[\"exp_avg_sq_col\"].to(grad)\n                    else:\n                        state[\"exp_avg_sq\"] = state[\"exp_avg_sq\"].to(grad)\n\n                p_data_fp32 = p\n                # NOTE: keep in float32\n                # if p.dtype in {torch.float16, torch.bfloat16}:\n                #     p_data_fp32 = p_data_fp32.float()\n\n                state[\"step\"] += 1\n                state[\"RMS\"] = self._rms(p_data_fp32)\n                lr = self._get_lr(group, state)\n\n                beta2t = 1.0 - math.pow(state[\"step\"], group[\"decay_rate\"])\n                beta2t = min(beta2t, group[\"beta2\"])\n                update = (grad**2) + group[\"eps\"][0]\n                if factored:\n                    exp_avg_sq_row = state[\"exp_avg_sq_row\"]\n                    exp_avg_sq_col = state[\"exp_avg_sq_col\"]\n\n                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))\n                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))\n\n                    # Approximation of exponential moving average of square of gradient\n                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)\n                    update.mul_(grad)\n                else:\n                    exp_avg_sq = state[\"exp_avg_sq\"]\n\n                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))\n                    update = exp_avg_sq.rsqrt().mul_(grad)\n\n                update.div_((self._rms(update) / group[\"clip_threshold\"]).clamp_(min=1.0))\n                update.mul_(lr)\n\n                if use_first_moment:\n                    exp_avg = state[\"exp_avg\"]\n                    exp_avg.mul_(group[\"beta1\"]).add_(update, alpha=(1 - group[\"beta1\"]))\n                    update = exp_avg\n\n                if group[\"weight_decay\"] != 0:\n                    p_data_fp32.add_(p_data_fp32, alpha=(-group[\"weight_decay\"] * lr))\n\n                p_data_fp32.add_(-update)\n\n                # if 
p.dtype in {torch.float16, torch.bfloat16}:\n                #     p.copy_(p_data_fp32)\n\n        return loss\n\n\nclass Lion(Optimizer):\n    \"\"\"\n    Modified version of Lion in https://github.com/google/automl/blob/master/lion/lion_pytorch.py, \n    which refers to Paper: *Symbolic Discovery of Optimization Algorithms* https://arxiv.org/pdf/2302.06675\n    \"\"\"\n    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):\n        \"\"\"Initialize the hyperparameters.\n\n        Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-4)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.99))\n        weight_decay (float, optional): weight decay coefficient (default: 0)\n        \"\"\"\n\n        if not 0.0 <= lr:\n            raise ValueError('Invalid learning rate: {}'.format(lr))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)\n        super().__init__(params, defaults)\n\n    @torch.no_grad()\n    def step(self, closure=None):\n        loss = None\n        if closure is not None:\n            with torch.enable_grad():\n                loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n\n                # Perform stepweight decay\n                p.data.mul_(1 - group['lr'] * group['weight_decay'])\n\n                grad = p.grad\n                state = self.state[p]\n                # State initialization\n                if len(state) == 0:\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p)\n\n                exp_avg = state['exp_avg']\n                beta1, beta2 = group['betas']\n\n                # Weight update\n                update = exp_avg * beta1 + grad * (1 - beta1)\n\n                p.add_(update.sign_(), alpha=-group['lr'])\n\n                # Decay the momentum running average coefficient\n                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)\n\n        return loss\n"
  },
  {
    "path": "inf_clip/train/params.py",
    "content": "import os\nimport argparse\nimport ast\nimport json\n\nfrom .utils import world_info_from_env\n\n\ndef get_default_params(model_name):\n    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)\n    model_name = model_name.lower()\n    if \"vit\" in model_name:\n        return {\"lr\": 5.0e-4, \"beta1\": 0.9, \"beta2\": 0.98, \"eps\": 1.0e-6}\n    else:\n        return {\"lr\": 5.0e-4, \"beta1\": 0.9, \"beta2\": 0.999, \"eps\": 1.0e-8}\n\n\nclass ParseKwargs(argparse.Action):\n    def __call__(self, parser, namespace, values, option_string=None):\n        kw = {}\n        for value in values:\n            key, value = value.split('=')\n            try:\n                kw[key] = ast.literal_eval(value)\n            except ValueError:\n                kw[key] = str(value)  # fallback to string (avoid need to escape on command line)\n        setattr(namespace, self.dest, kw)\n\n\ndef parse_args(args):\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--train-data\",\n        type=str,\n        default=None,\n        help=\"Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.\",\n    )\n    parser.add_argument(\n        \"--train-data-upsampling-factors\",\n        type=str,\n        default=None,\n        help=(\n            \"When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. \"\n            \"Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) \"\n            \"By default, datapoints are sampled uniformly regardless of the dataset sizes.\"\n        )\n    )\n    parser.add_argument(\n        \"--val-data\",\n        type=str,\n        default=None,\n        help=\"Path to file(s) with validation data\",\n    )\n    parser.add_argument(\n        \"--train-num-samples\",\n        type=int,\n        default=None,\n        help=\"Number of samples in dataset. Required for webdataset if not available in info file.\",\n    )\n    parser.add_argument(\n        \"--val-num-samples\",\n        type=int,\n        default=None,\n        help=\"Number of samples in dataset. 
Useful for webdataset if not available in info file.\",\n    )\n    parser.add_argument(\n        \"--dataset-type\",\n        choices=[\"webdataset\", \"csv\", \"synthetic\", \"auto\"],\n        default=\"auto\",\n        help=\"Which type of dataset to process.\"\n    )\n    parser.add_argument(\n        \"--dataset-resampled\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use sampling with replacement for webdataset shard selection.\"\n    )\n    parser.add_argument(\n        \"--csv-separator\",\n        type=str,\n        default=\"\\t\",\n        help=\"For csv-like datasets, which separator to use.\"\n    )\n    parser.add_argument(\n        \"--csv-img-key\",\n        type=str,\n        default=\"filepath\",\n        help=\"For csv-like datasets, the name of the key for the image paths.\"\n    )\n    parser.add_argument(\n        \"--csv-caption-key\",\n        type=str,\n        default=\"title\",\n        help=\"For csv-like datasets, the name of the key for the captions.\"\n    )\n    parser.add_argument(\n        \"--imagenet-val\",\n        type=str,\n        default=None,\n        help=\"Path to imagenet val set for conducting zero shot evaluation.\",\n    )\n    parser.add_argument(\n        \"--imagenet-v2\",\n        type=str,\n        default=None,\n        help=\"Path to imagenet v2 for conducting zero shot evaluation.\",\n    )\n    parser.add_argument(\n        \"--log_dir\",\n        type=str,\n        default=\"./logs/\",\n        help=\"Where to store tensorboard logs. Use None to avoid storing logs.\",\n    )\n    parser.add_argument(\n        \"--log-local\",\n        action=\"store_true\",\n        default=False,\n        help=\"log files on local master, otherwise global master only.\",\n    )\n    parser.add_argument(\n        \"--name\",\n        type=str,\n        default=None,\n        help=\"Optional identifier for the experiment when storing logs. 
Otherwise use current time.\",\n    )\n    parser.add_argument(\n        \"--workers\", type=int, default=4, help=\"Number of dataloader workers per GPU.\"\n    )\n    parser.add_argument(\n        \"--batch-size\", type=int, default=64, help=\"Batch size per GPU.\"\n    )\n    parser.add_argument(\n        \"--epochs\", type=int, default=32, help=\"Number of epochs to train for.\"\n    )\n    parser.add_argument(\n        \"--epochs-cooldown\", type=int, default=None,\n        help=\"When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.\"\n    )\n    parser.add_argument(\"--optimizer\", type=str, default=\"adam\", help=\"Optimizer to use.\")\n    parser.add_argument(\"--lr\", type=float, default=None, help=\"Learning rate.\")\n    parser.add_argument(\"--beta1\", type=float, default=None, help=\"coefficient of moving average of first moment.\")\n    parser.add_argument(\"--beta2\", type=float, default=None, help=\"coefficient of moving average of second moment.\")\n    parser.add_argument(\"--eps\", type=float, default=None, help=\"Adam epsilon.\")\n    parser.add_argument(\"--wd\", type=float, default=0.2, help=\"Weight decay.\")\n    parser.add_argument(\n        \"--warmup\", type=int, default=10000, help=\"Number of steps to warmup for.\"\n    )\n    parser.add_argument(\n        \"--use-bn-sync\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to use batch norm sync.\")\n    parser.add_argument(\n        \"--skip-scheduler\",\n        action=\"store_true\",\n        default=False,\n        help=\"Use this flag to skip the learning rate decay.\",\n    )\n    parser.add_argument(\n        \"--lr-scheduler\",\n        type=str,\n        default='cosine',\n        help=\"LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine\",\n    )\n    parser.add_argument(\n        \"--lr-cooldown-end\", type=float, default=0.0,\n        help=\"End learning rate for cooldown schedule. Default: 0\"\n    )\n    parser.add_argument(\n        \"--lr-cooldown-power\", type=float, default=1.0,\n        help=\"Power for polynomial cooldown schedule. 
Default: 1.0 (linear decay)\"\n    )\n    parser.add_argument(\n        \"--save-frequency\", type=int, default=1, help=\"How often to save checkpoints.\"\n    )\n    parser.add_argument(\n        \"--save-most-recent\",\n        action=\"store_true\",\n        default=False,\n        help=\"Always save the most recent model trained to epoch_latest.pt.\",\n    )\n    parser.add_argument(\n        \"--zeroshot-frequency\", type=int, default=2, help=\"How often to run zero shot.\"\n    )\n    parser.add_argument(\n        \"--val-frequency\", type=int, default=1, help=\"How often to run evaluation with val data.\"\n    )\n    parser.add_argument(\n        \"--resume\",\n        default=None,\n        type=str,\n        help=\"path to latest checkpoint (default: none)\",\n    )\n    parser.add_argument(\n        \"--precision\",\n        choices=[\"amp\", \"amp_bf16\", \"amp_bfloat16\", \"bf16\", \"fp16\", \"pure_bf16\", \"pure_fp16\", \"fp32\"],\n        default=\"amp\",\n        help=\"Floating point precision.\"\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        default=\"RN50\",\n        help=\"Name of the vision backbone to use.\",\n    )\n    parser.add_argument(\n        \"--pretrained\",\n        default='',\n        type=str,\n        help=\"Use a pretrained CLIP model weights with the specified tag or file path.\",\n    )\n    parser.add_argument(\n        \"--pretrained-image\",\n        default=False,\n        action='store_true',\n        help=\"Load imagenet pretrained weights for image tower backbone if available.\",\n    )\n    parser.add_argument(\n        \"--lock-image\",\n        default=False,\n        action='store_true',\n        help=\"Lock full image tower by disabling gradients.\",\n    )\n    parser.add_argument(\n        \"--lock-image-unlocked-groups\",\n        type=int,\n        default=0,\n        help=\"Leave last n image tower layer groups unlocked.\",\n    )\n    parser.add_argument(\n        \"--lock-image-freeze-bn-stats\",\n        default=False,\n        action='store_true',\n        help=\"Freeze BatchNorm running stats in image tower for any locked layers.\",\n    )\n    parser.add_argument(\n        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',\n        help='Override default image mean value of dataset')\n    parser.add_argument(\n        '--image-std', type=float, nargs='+', default=None, metavar='STD',\n        help='Override default image std deviation of dataset')\n    parser.add_argument(\n        '--image-interpolation',\n        default=None, type=str, choices=['bicubic', 'bilinear', 'random'],\n        help=\"Override default image resize interpolation\"\n    )\n    parser.add_argument(\n        '--image-resize-mode',\n        default=None, type=str, choices=['shortest', 'longest', 'squash'],\n        help=\"Override default image resize (& crop) mode during inference\"\n    )\n    parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)\n    parser.add_argument(\n        \"--grad-checkpointing\",\n        default=False,\n        action='store_true',\n        help=\"Enable gradient checkpointing.\",\n    )\n    parser.add_argument(\n        \"--local-loss\",\n        default=False,\n        action=\"store_true\",\n        help=\"calculate loss w/ local features @ global (instead of realizing full global @ global matrix)\"\n    )\n    parser.add_argument(\n        \"--gather-with-grad\",\n        default=False,\n        action=\"store_true\",\n        help=\"enable full 
distributed gradient for feature gather\"\n    )\n    parser.add_argument(\n        '--force-image-size', type=int, nargs='+', default=None,\n        help='Override default image size'\n    )\n    parser.add_argument(\n        \"--force-quick-gelu\",\n        default=False,\n        action='store_true',\n        help=\"Force use of QuickGELU activation for non-OpenAI transformer models.\",\n    )\n    parser.add_argument(\n        \"--force-patch-dropout\",\n        default=None,\n        type=float,\n        help=\"Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper\",\n    )\n    parser.add_argument(\n        \"--force-custom-text\",\n        default=False,\n        action='store_true',\n        help=\"Force use of CustomTextCLIP model (separate text-tower).\",\n    )\n    parser.add_argument(\n        \"--torchscript\",\n        default=False,\n        action='store_true',\n        help=\"torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'\",\n    )\n    parser.add_argument(\n        \"--torchcompile\",\n        default=False,\n        action='store_true',\n        help=\"torch.compile() the model, requires pytorch 2.0 or later.\",\n    )\n    parser.add_argument(\n        \"--trace\",\n        default=False,\n        action='store_true',\n        help=\"torch.jit.trace the model for inference / eval only\",\n    )\n    parser.add_argument(\n        \"--accum-freq\", type=int, default=1, help=\"Update the model every --acum-freq steps.\"\n    )\n    # arguments for distributed training\n    parser.add_argument(\n        \"--dist-url\",\n        default=\"env://\",\n        type=str,\n        help=\"url used to set up distributed training\",\n    )\n    parser.add_argument(\n        \"--dist-backend\", default=\"nccl\", type=str, help=\"distributed backend\"\n    )\n    parser.add_argument(\n        \"--report-to\",\n        default='',\n        type=str,\n        help=\"Options are ['wandb', 'tensorboard', 'wandb,tensorboard']\"\n    )\n    parser.add_argument(\n        \"--wandb-notes\",\n        default='',\n        type=str,\n        help=\"Notes if logging with wandb\"\n    )\n    parser.add_argument(\n        \"--wandb-project-name\",\n        type=str,\n        default='open-clip',\n        help=\"Name of the project if logging with wandb.\",\n    )\n    parser.add_argument(\n        \"--debug\",\n        default=False,\n        action=\"store_true\",\n        help=\"If true, more information is logged.\"\n    )\n    parser.add_argument(\n        \"--copy-codebase\",\n        default=False,\n        action=\"store_true\",\n        help=\"If true, we copy the entire base on the log directory, and execute from there.\"\n    )\n    parser.add_argument(\n        \"--horovod\",\n        default=False,\n        action=\"store_true\",\n        help=\"Use horovod for distributed training.\"\n    )\n    parser.add_argument(\n        \"--ddp-static-graph\",\n        default=False,\n        action='store_true',\n        help=\"Enable static graph optimization for DDP in PyTorch >= 1.11.\",\n    )\n    parser.add_argument(\n        \"--no-set-device-rank\",\n        default=False,\n        action=\"store_true\",\n        help=\"Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).\"\n    )\n    parser.add_argument(\n        \"--seed\", type=int, default=0, help=\"Default random seed.\"\n    )\n    parser.add_argument(\n        \"--grad-clip-norm\", type=float, 
default=None, help=\"Gradient clip.\"\n    )\n    parser.add_argument(\n        \"--lock-text\",\n        default=False,\n        action='store_true',\n        help=\"Lock full text tower by disabling gradients.\",\n    )\n    parser.add_argument(\n        \"--lock-text-unlocked-layers\",\n        type=int,\n        default=0,\n        help=\"Leave last n text tower layer groups unlocked.\",\n    )\n    parser.add_argument(\n        \"--lock-text-freeze-layer-norm\",\n        default=False,\n        action='store_true',\n        help=\"Freeze LayerNorm running stats in text tower for any locked layers.\",\n    )\n    parser.add_argument(\n        \"--log-every-n-steps\",\n        type=int,\n        default=100,\n        help=\"Log every n steps to tensorboard/console/wandb.\",\n    )\n    parser.add_argument(\n        \"--coca-caption-loss-weight\",\n        type=float,\n        default=2.0,\n        help=\"Weight assigned to caption loss in CoCa.\"\n    )\n    parser.add_argument(\n        \"--coca-contrastive-loss-weight\",\n        type=float,\n        default=1.0,\n        help=\"Weight assigned to contrastive loss when training CoCa.\"\n    )\n    parser.add_argument(\n        \"--remote-sync\",\n        type=str,\n        default=None,\n        help=\"Optinoally sync with a remote path specified by this arg\",\n    )\n    parser.add_argument(\n        \"--remote-sync-frequency\",\n        type=int,\n        default=300,\n        help=\"How frequently to sync to a remote directly if --remote-sync is not None.\",\n    )\n    parser.add_argument(\n        \"--remote-sync-protocol\",\n        choices=[\"s3\", \"fsspec\"],\n        default=\"s3\",\n        help=\"How to do the remote sync backup if --remote-sync is not None.\",\n    )\n    parser.add_argument(\n        \"--delete-previous-checkpoint\",\n        default=False,\n        action=\"store_true\",\n        help=\"If true, delete previous checkpoint after storing a new one.\"\n    )\n    parser.add_argument(\n        \"--distill-model\",\n        default=None,\n        help='Which model arch to distill from, if any.'\n    )\n    parser.add_argument(\n        \"--distill-pretrained\",\n        default=None,\n        help='Which pre-trained weights to distill from, if any.'\n    )\n    parser.add_argument(\n        \"--use-bnb-linear\",\n        default=None,\n        help='Replace the network linear layers from the bitsandbytes library. 
'\n        'Allows int8 training/inference, etc.'\n    )\n    parser.add_argument(\n        \"--siglip\",\n        default=False,\n        action=\"store_true\",\n        help='Use SigLip (sigmoid) loss.'\n    )\n    parser.add_argument(\n        \"--flashloss\",\n        default=False,\n        action=\"store_true\",\n        help='Use flash loss.'\n    )\n    parser.add_argument(\n        \"--ringloss\",\n        default=False,\n        action=\"store_true\",\n        help='Use ring loss.'\n    )\n    parser.add_argument(\n        \"--infloss\",\n        default=False,\n        action=\"store_true\",\n        help='Use ring flash loss.'\n    )\n    parser.add_argument(\n        \"--discoloss\",\n        default=False,\n        action=\"store_true\",\n        help='Use disc loss.'\n    )\n\n    try:\n        import deepspeed\n        parser = deepspeed.add_config_arguments(parser)\n        parser.add_argument('--zero-stage', type=int, default=1, help='stage of ZERO')\n    except:\n        print(\"Please 'pip install deepspeed==0.8.1'\")\n        exit(0)\n\n    args = parser.parse_args(args)\n\n    if args.deepspeed:\n        create_deepspeed_config(args)\n\n    # If some params are not passed, we use the default values based on model name.\n    default_params = get_default_params(args.model)\n    for name, val in default_params.items():\n        if getattr(args, name) is None:\n            setattr(args, name, val)\n\n    return args\n\n\ndef create_deepspeed_config(args):\n    _, _, world_size = world_info_from_env()\n    args.deepspeed_config = os.path.join(os.getcwd(), \"scripts\", \"deepspeed_config.json\")\n    # default optimizer\n    optim_settings = None\n    if args.optimizer.lower() == \"adamw\":\n        optim_settings = {\n            \"type\": \"Adam\",\n            \"adam_w_mode\": True,\n            \"params\": {\n                \"bias_correction\": True,\n                \"betas\": [\n                    args.beta1,\n                    args.beta2\n                ],\n                \"eps\": args.eps,\n            }\n        }\n    # LAMB\n    elif args.optimizer.lower() == \"lamb\":\n        # https://arxiv.org/pdf/1904.00962.pdf\n        optim_settings = {\n            \"type\": \"LAMB\",\n            \"params\": {\n            \"bias_correction\": True,\n            \"betas\": [\n                args.beta1,\n                args.beta2\n            ],\n            \"eps\": args.eps,\n            \"max_coeff\": 10.0, #0.3\n            \"min_coeff\": 0.01,\n            \"eps_inside_sqrt\": False,\n            }\n        }\n    if args.optimizer.lower() == \"1bitlamb\":\n        # not supported\n        # 1bit-Lamb is not compatible with ZeRO; zero-stage should be 0\n        # https://arxiv.org/abs/2104.06069\n        optim_settings = {\n            \"type\": \"OneBitLamb\",\n            \"params\": {\n            \"bias_correction\": True,\n            \"betas\": [\n                args.beta1,\n                args.beta2\n            ],\n            \"eps\": args.eps,\n            \"max_coeff\": 10.0, #0.3\n            \"min_coeff\": 0.01,\n            \"eps_inside_sqrt\": False,\n            \"freeze_step\": args.warmup,\n            # \"comm_backend_name\": \"nccl\",\n            # \"coeff_beta\": 0.9,\n            # \"factor_max\": 4.0,\n            # \"factor_min\": 0.5,\n            # \"factor_threshold\": 0.1\n            }\n        }\n\n    with open(args.deepspeed_config, mode=\"w\") as writer:\n        ds_config = {\n            \"train_batch_size\": 
args.batch_size * world_size * args.accum_freq,\n            \"train_micro_batch_size_per_gpu\": args.batch_size,\n            \"gradient_accumulation_steps\": args.accum_freq,\n            \"gradient_accumulation_dtype\": \"fp32\",\n            \"steps_per_print\": 1000000,\n            \"zero_allow_untested_optimizer\": True,\n            \"fp16\": {\n                \"enabled\": True if args.precision != \"bf16\" else False,\n                # \"auto_cast\": True,\n                \"loss_scale\": 0,\n                \"initial_scale_power\": 0,\n                \"loss_scale_window\": 1000,\n                \"hysteresis\": 2,\n                \"min_loss_scale\": 1\n            },\n            \"bf16\": {\n                \"enabled\": args.precision == \"bf16\"\n            },\n            \"amp\": {\n                \"enabled\": False,\n                \"opt_level\": \"O2\"\n            },\n            \"flops_profiler\": {\n                \"enabled\": True,\n                \"profile_step\": -1,\n                \"module_depth\": -1,\n                \"top_modules\": 1,\n                \"detailed\": True,\n            },\n            \"activation_checkpointing\": {\n                \"partition_activations\": args.grad_checkpointing,\n                \"contiguous_memory_optimization\": False,\n                \"profile\": True\n            },\n            # \"wallclock_breakdown\": True\n        }\n\n        if optim_settings is not None:\n            ds_config.update({'optimizer': optim_settings})\n\n        if args.grad_clip_norm is not None:\n            ds_config.update({'gradient_clipping': args.grad_clip_norm})\n\n        if args.zero_stage == 1:\n            ds_config.update(\n                {\n                    \"zero_optimization\": {\n                        \"stage\": 1, \n                        \"reduce_bucket_size\": 5e8,\n                    }\n                }\n            )\n        elif args.zero_stage == 2:\n            ds_config.update(\n                {\n                    \"zero_optimization\": {\n                    \"stage\": 2,\n                    \"contiguous_gradients\": ('vit-b' not in args.model.lower()), # should be False if model is small,\n                    \"overlap_comm\": True,\n                    \"reduce_scatter\": True,\n                    \"reduce_bucket_size\": 5e8,\n                    \"allgather_bucket_size\": 5e8,\n                    \"cpu_offload\": False \n                    }\n                }\n            )\n        elif args.zero_stage == 3:\n            ds_config.update(\n                {\n                    \"zero_optimization\": {\n                        \"stage\": 3,\n                        \"contiguous_gradients\": True,\n                        \"overlap_comm\": True,\n                        \"reduce_scatter\": True,\n                        \"reduce_bucket_size\": 5e4,\n                        \"allgather_bucket_size\": 5e4,\n                        \"cpu_offload\": False,\n                    },\n                    \"stage3_max_live_parameters\": 1e5,\n                    \"stage3_max_reuse_distance\": 1e5,\n                }\n            )\n        elif args.zero_stage > 3:\n            raise NotImplementedError()\n\n        writer.write(json.dumps(ds_config, indent=2))\n"
  },
  {
    "path": "inf_clip/train/utils.py",
    "content": "import os\nimport time\nimport logging\nimport subprocess\nimport multiprocessing\n\nimport torch\nimport torch.distributed as dist\nimport fsspec\nfrom tqdm import tqdm\nfrom contextlib import suppress\n\ntry:\n    import horovod.torch as hvd\nexcept ImportError:\n    hvd = None\n\n\ndef setup_logging(log_file, level, include_host=False):\n    if include_host:\n        import socket\n        hostname = socket.gethostname()\n        formatter = logging.Formatter(\n            f'%(asctime)s |  {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')\n    else:\n        formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S')\n\n    logging.root.setLevel(level)\n    loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]\n    for logger in loggers:\n        logger.setLevel(level)\n\n    stream_handler = logging.StreamHandler()\n    stream_handler.setFormatter(formatter)\n    logging.root.addHandler(stream_handler)\n\n    if log_file:\n        file_handler = logging.FileHandler(filename=log_file)\n        file_handler.setFormatter(formatter)\n        logging.root.addHandler(file_handler)\n\n\ndef remote_sync_s3(local_dir, remote_dir):\n    # skip epoch_latest which can change during sync.\n    result = subprocess.run([\"aws\", \"s3\", \"sync\", local_dir, remote_dir, '--exclude', '*epoch_latest.pt'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n    if result.returncode != 0:\n        logging.error(f\"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}\")\n        return False\n        \n    logging.info(f\"Successfully synced with S3 bucket\")\n    return True\n\n\ndef remote_sync_fsspec(local_dir, remote_dir):\n    # FIXME currently this is slow and not recommended. 
Look into speeding up.\n    a = fsspec.get_mapper(local_dir)\n    b = fsspec.get_mapper(remote_dir)\n\n    for k in a:\n        # skip epoch_latest which can change during sync.\n        if 'epoch_latest.pt' in k:\n            continue\n\n        logging.info(f'Attempting to sync {k}')\n        if k in b and len(a[k]) == len(b[k]):\n            logging.debug(f'Skipping remote sync for {k}.')\n            continue\n\n        try:\n            logging.info(f'Successful sync for {k}.')\n            b[k] = a[k]\n        except Exception as e:\n            logging.info(f'Error during remote sync for {k}: {e}')\n            return False\n\n    return True\n\n\ndef remote_sync(local_dir, remote_dir, protocol):\n    logging.info('Starting remote sync.')\n    if protocol == 's3':\n        return remote_sync_s3(local_dir, remote_dir)\n    elif protocol == 'fsspec':\n        return remote_sync_fsspec(local_dir, remote_dir)\n    else:\n        logging.error('Remote protocol not known')\n        return False\n\n\ndef keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):\n    while True:\n        time.sleep(sync_every)\n        remote_sync(local_dir, remote_dir, protocol)\n\n\ndef start_sync_process(sync_every, local_dir, remote_dir, protocol):\n    p = multiprocessing.Process(target=keep_running_remote_sync, args=(sync_every, local_dir, remote_dir, protocol))\n    return p\n\n\n# Note: we are not currently using this save function.\ndef pt_save(pt_obj, file_path):\n    of = fsspec.open(file_path, \"wb\")\n    with of as f:\n        torch.save(pt_obj, file_path)\n\n\ndef pt_load(file_path, map_location=None):\n    if file_path.startswith('s3'):\n        logging.info('Loading remote checkpoint, which may take a bit.')\n    of = fsspec.open(file_path, \"rb\")\n    with of as f:\n        out = torch.load(f, map_location=map_location)\n    return out\n\n\ndef check_exists(file_path):\n    try:\n        with fsspec.open(file_path):\n            pass\n    except FileNotFoundError:\n        return False\n    return True\n\n\ndef get_autocast(precision):\n    if precision == 'amp':\n        return lambda: torch.amp.autocast(\"cuda\", dtype=torch.bfloat16)\n    elif precision == 'amp_bfloat16' or precision == 'amp_bf16':\n        # amp_bfloat16 is more stable than amp float16 for clip training\n        return lambda: torch.amp.autocast(\"cuda\", dtype=torch.bfloat16)\n    else:\n        return suppress\n\n\n##################################\n# Distributed training utilities #\n##################################\n\ndef is_global_master(args):\n    return args.rank == 0\n\n\ndef is_local_master(args):\n    return args.local_rank == 0\n\n\ndef is_master(args, local=False):\n    return is_local_master(args) if local else is_global_master(args)\n\n\ndef is_using_horovod():\n    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set\n    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...\n    ompi_vars = [\"OMPI_COMM_WORLD_RANK\", \"OMPI_COMM_WORLD_SIZE\"]\n    pmi_vars = [\"PMI_RANK\", \"PMI_SIZE\"]\n    if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):\n        return True\n    else:\n        return False\n\n\ndef is_using_distributed():\n    if 'WORLD_SIZE' in os.environ:\n        return int(os.environ['WORLD_SIZE']) > 1\n    if 'SLURM_NTASKS' in os.environ:\n        return int(os.environ['SLURM_NTASKS']) > 1\n    return False\n\n\ndef world_info_from_env():\n    
local_rank = 0\n    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):\n        if v in os.environ:\n            local_rank = int(os.environ[v])\n            break\n    global_rank = 0\n    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):\n        if v in os.environ:\n            global_rank = int(os.environ[v])\n            break\n    world_size = 1\n    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):\n        if v in os.environ:\n            world_size = int(os.environ[v])\n            break\n\n    return local_rank, global_rank, world_size\n\n\ndef init_distributed_device(args):\n    # Distributed training = training on more than one GPU.\n    # Works in both single and multi-node scenarios.\n    args.distributed = False\n    args.world_size = 1\n    args.rank = 0  # global rank\n    args.local_rank = 0\n    if args.horovod:\n        assert hvd is not None, \"Horovod is not installed\"\n        hvd.init()\n        args.local_rank = int(hvd.local_rank())\n        args.rank = hvd.rank()\n        args.world_size = hvd.size()\n        args.distributed = True\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n        os.environ['RANK'] = str(args.rank)\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n    elif is_using_distributed():\n        if 'SLURM_PROCID' in os.environ:\n            # DDP via SLURM\n            args.local_rank, args.rank, args.world_size = world_info_from_env()\n            # SLURM var -> torch.distributed vars in case needed\n            os.environ['LOCAL_RANK'] = str(args.local_rank)\n            os.environ['RANK'] = str(args.rank)\n            os.environ['WORLD_SIZE'] = str(args.world_size)\n            torch.distributed.init_process_group(\n                backend=args.dist_backend,\n                init_method=args.dist_url,\n                world_size=args.world_size,\n                rank=args.rank,\n            )\n        else:\n            # DDP via torchrun, torch.distributed.launch\n            args.local_rank, _, _ = world_info_from_env()\n            torch.distributed.init_process_group(\n                backend=args.dist_backend,\n                init_method=args.dist_url)\n            args.world_size = torch.distributed.get_world_size()\n            args.rank = torch.distributed.get_rank()\n        args.distributed = True\n    else:\n        # DDP via torchrun, torch.distributed.launch\n        args.local_rank, _, _ = world_info_from_env()\n        torch.distributed.init_process_group(\n            backend=args.dist_backend,\n            init_method=args.dist_url)\n        args.world_size = torch.distributed.get_world_size()\n        args.rank = torch.distributed.get_rank()\n        args.distributed = True\n\n    if torch.cuda.is_available():\n        if args.distributed and not args.no_set_device_rank:\n            device = 'cuda:%d' % args.local_rank\n        else:\n            device = 'cuda:0'\n        torch.cuda.set_device(device)\n    else:\n        device = 'cpu'\n    args.device = device\n    device = torch.device(device)\n    return device\n\n\ndef broadcast_object(args, obj, src=0):\n    # broadcast a pickle-able python object from rank-0 to all ranks\n    if args.horovod:\n        return hvd.broadcast_object(obj, root_rank=src)\n    else:\n        if args.rank == src:\n            objects = [obj]\n        else:\n            objects = [None]\n        dist.broadcast_object_list(objects, src=src)\n        return objects[0]\n\n\ndef 
all_gather_object(args, obj, dst=0):\n    # gather a pickle-able python object across all ranks\n    if args.horovod:\n        return hvd.allgather_object(obj)\n    else:\n        objects = [None for _ in range(args.world_size)]\n        dist.all_gather_object(objects, obj)\n        return objects\n"
  },
  {
    "path": "inf_clip/utils.py",
    "content": "from itertools import repeat\nimport collections.abc\n\nimport torch\nfrom torch import nn as nn\nfrom torchvision.ops.misc import FrozenBatchNorm2d\n\n\ndef freeze_batch_norm_2d(module, module_match={}, name=''):\n    \"\"\"\n    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\n    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\n    returned. Otherwise, the module is walked recursively and submodules are converted in place.\n\n    Args:\n        module (torch.nn.Module): Any PyTorch module.\n        module_match (dict): Dictionary of full module names to freeze (all if empty)\n        name (str): Full module name (prefix)\n\n    Returns:\n        torch.nn.Module: Resulting module\n\n    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\n    \"\"\"\n    res = module\n    is_match = True\n    if module_match:\n        is_match = name in module_match\n    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):\n        res = FrozenBatchNorm2d(module.num_features)\n        res.num_features = module.num_features\n        res.affine = module.affine\n        if module.affine:\n            res.weight.data = module.weight.data.clone().detach()\n            res.bias.data = module.bias.data.clone().detach()\n        res.running_mean.data = module.running_mean.data\n        res.running_var.data = module.running_var.data\n        res.eps = module.eps\n    else:\n        for child_name, child in module.named_children():\n            full_child_name = '.'.join([name, child_name]) if name else child_name\n            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\n            if new_child is not child:\n                res.add_module(child_name, new_child)\n    return res\n\n\n# From PyTorch internals\ndef _ntuple(n):\n    def parse(x):\n        if isinstance(x, collections.abc.Iterable):\n            return x\n        return tuple(repeat(x, n))\n    return parse\n\n\nto_1tuple = _ntuple(1)\nto_2tuple = _ntuple(2)\nto_3tuple = _ntuple(3)\nto_4tuple = _ntuple(4)\nto_ntuple = lambda n, x: _ntuple(n)(x)\n\n# Replaces all linear layers with linear_replacement\n# TODO: add int8 support for other linear layers including attn and convnets\ndef replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):\n    for name, module in model.named_children():\n        if len(list(module.children())) > 0:\n            replace_linear(module, linear_replacement, include_modules, copy_weights)\n\n        if isinstance(module, torch.nn.Linear) and name in include_modules:\n            old_module = model._modules[name]\n            model._modules[name] = linear_replacement(\n                module.in_features,\n                module.out_features,\n                module.bias is not None,\n            )\n            if copy_weights:\n                model._modules[name].weight.data.copy_(old_module.weight.data)\n                if model._modules[name].bias is not None:\n                    model._modules[name].bias.data.copy_(old_module.bias)\n\n    return model\n\ndef convert_int8_model_to_inference_mode(model):\n    for m in model.modules():\n        if hasattr(m, 'prepare_for_eval'):\n            int8_original_dtype = m.weight.dtype\n            m.prepare_for_eval()\n            
m.int8_original_dtype = int8_original_dtype"
  },
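  {
    "path": "inf_clip/utils_example.py",
    "content": "# Illustrative usage sketch (added as an example; not referenced by the training code).\n# It exercises the two helpers defined in inf_clip/utils.py on a torchvision ResNet-50, which is only a\n# stand-in model: freeze_batch_norm_2d converts BatchNorm2d layers to FrozenBatchNorm2d, and\n# replace_linear swaps named nn.Linear submodules for another linear implementation.\nimport torch\nfrom torchvision.models import resnet50\n\nfrom inf_clip.utils import freeze_batch_norm_2d, replace_linear\n\n\ndef main():\n    model = resnet50()\n\n    # Freeze every BatchNorm2d layer: running statistics and affine parameters stay fixed during fine-tuning.\n    model = freeze_batch_norm_2d(model)\n\n    # Swap the classification head ('fc') for another linear implementation, copying its weights.\n    # nn.Linear itself is used here only to keep the sketch dependency-free; the default\n    # include_modules=['c_fc', 'c_proj'] targets the MLP layers of a CLIP transformer block instead.\n    model = replace_linear(model, torch.nn.Linear, include_modules=['fc'], copy_weights=True)\n\n    with torch.no_grad():\n        out = model(torch.randn(1, 3, 224, 224))\n    print(out.shape)  # torch.Size([1, 1000])\n\n\nif __name__ == '__main__':\n    main()\n"
  },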
  {
    "path": "inf_clip/zero_shot_classifier.py",
    "content": "from functools import partial\nfrom itertools import islice\nfrom typing import Callable, List, Optional, Sequence, Union\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef batched(iterable, n):\n    \"\"\"Batch data into lists of length *n*. The last batch may be shorter.\n    NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl\n    \"\"\"\n    it = iter(iterable)\n    while True:\n        batch = list(islice(it, n))\n        if not batch:\n            break\n        yield batch\n\n\ndef build_zero_shot_classifier(\n        model,\n        tokenizer,\n        classnames: Sequence[str],\n        templates: Sequence[Union[Callable, str]],\n        num_classes_per_batch: Optional[int] = 10,\n        device: Union[str, torch.device] = 'cpu',\n        use_tqdm: bool = False,\n):\n    \"\"\" Build zero-shot classifier weights by iterating over class names in batches\n    Args:\n        model: CLIP model instance\n        tokenizer: CLIP tokenizer instance\n        classnames: A sequence of class (label) names\n        templates: A sequence of callables or format() friendly strings to produce templates per class name\n        num_classes_per_batch: The number of classes to batch together in each forward, all if None\n        device: Device to use.\n        use_tqdm: Enable TQDM progress bar.\n    \"\"\"\n    assert isinstance(templates, Sequence) and len(templates) > 0\n    assert isinstance(classnames, Sequence) and len(classnames) > 0\n    use_format = isinstance(templates[0], str)\n    num_templates = len(templates)\n    num_classes = len(classnames)\n    if use_tqdm:\n        import tqdm\n        num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1)\n        iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch)\n    else:\n        iter_wrap = iter\n\n    def _process_batch(batch_classnames):\n        num_batch_classes = len(batch_classnames)\n        texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]\n        texts = tokenizer(texts).to(device)\n        class_embeddings = model.encode_text(texts, normalize=True)\n        class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)\n        class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)\n        class_embeddings = class_embeddings.T\n        return class_embeddings\n\n    with torch.no_grad():\n        if num_classes_per_batch:\n            batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))]\n            zeroshot_weights = torch.cat(batched_embeds, dim=1)\n        else:\n            zeroshot_weights = _process_batch(classnames)\n    return zeroshot_weights\n\n\ndef build_zero_shot_classifier_legacy(\n        model,\n        tokenizer,\n        classnames: Sequence[str],\n        templates: Sequence[Union[Callable, str]],\n        device: Union[str, torch.device] = 'cpu',\n        use_tqdm: bool = False,\n):\n    \"\"\" Build zero-shot classifier weights by iterating over class names 1 by 1\n    Args:\n        model: CLIP model instance\n        tokenizer: CLIP tokenizer instance\n        classnames: A sequence of class (label) names\n        templates: A sequence of callables or format() friendly strings to produce templates per class name\n        device: Device to use.\n        use_tqdm: Enable TQDM progress bar.\n    \"\"\"\n    
assert isinstance(templates, Sequence) and len(templates) > 0\n    assert isinstance(classnames, Sequence) and len(classnames) > 0\n    if use_tqdm:\n        import tqdm\n        iter_wrap = tqdm.tqdm\n    else:\n        iter_wrap = iter\n\n    use_format = isinstance(templates[0], str)\n\n    with torch.no_grad():\n        zeroshot_weights = []\n        for classname in iter_wrap(classnames):\n            texts = [template.format(classname) if use_format else template(classname) for template in templates]\n            texts = tokenizer(texts).to(device)  # tokenize\n            class_embeddings = model.encode_text(texts)\n            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)\n            class_embedding /= class_embedding.norm()\n            zeroshot_weights.append(class_embedding)\n        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)\n\n    return zeroshot_weights\n\n"
  },
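  {
    "path": "inf_clip/zero_shot_example.py",
    "content": "# Illustrative sketch (added as an example; not wired into the training entrypoint).\n# It shows how build_zero_shot_classifier is typically combined with the ImageNet metadata in\n# inf_clip/zero_shot_metadata.py to score images. `model` and `tokenizer` are assumed to follow the\n# open_clip-style interface used elsewhere in this repo (encode_text / encode_image with a `normalize`\n# flag); adapt the call sites if your model differs.\nimport torch\n\nfrom inf_clip.zero_shot_classifier import build_zero_shot_classifier\nfrom inf_clip.zero_shot_metadata import IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES\n\n\n@torch.no_grad()\ndef zero_shot_predict(model, tokenizer, images, device='cuda'):\n    # classifier has shape [embed_dim, num_classes]: one averaged, re-normalized text embedding per class.\n    classifier = build_zero_shot_classifier(\n        model,\n        tokenizer,\n        classnames=IMAGENET_CLASSNAMES,\n        templates=OPENAI_IMAGENET_TEMPLATES,\n        num_classes_per_batch=10,\n        device=device,\n        use_tqdm=True,\n    )\n    image_features = model.encode_image(images.to(device), normalize=True)\n    logits = image_features @ classifier  # [batch_size, num_classes]\n    return logits.argmax(dim=-1)\n"
  },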
  {
    "path": "inf_clip/zero_shot_metadata.py",
    "content": "\nOPENAI_IMAGENET_TEMPLATES = (\n    lambda c: f'a bad photo of a {c}.',\n    lambda c: f'a photo of many {c}.',\n    lambda c: f'a sculpture of a {c}.',\n    lambda c: f'a photo of the hard to see {c}.',\n    lambda c: f'a low resolution photo of the {c}.',\n    lambda c: f'a rendering of a {c}.',\n    lambda c: f'graffiti of a {c}.',\n    lambda c: f'a bad photo of the {c}.',\n    lambda c: f'a cropped photo of the {c}.',\n    lambda c: f'a tattoo of a {c}.',\n    lambda c: f'the embroidered {c}.',\n    lambda c: f'a photo of a hard to see {c}.',\n    lambda c: f'a bright photo of a {c}.',\n    lambda c: f'a photo of a clean {c}.',\n    lambda c: f'a photo of a dirty {c}.',\n    lambda c: f'a dark photo of the {c}.',\n    lambda c: f'a drawing of a {c}.',\n    lambda c: f'a photo of my {c}.',\n    lambda c: f'the plastic {c}.',\n    lambda c: f'a photo of the cool {c}.',\n    lambda c: f'a close-up photo of a {c}.',\n    lambda c: f'a black and white photo of the {c}.',\n    lambda c: f'a painting of the {c}.',\n    lambda c: f'a painting of a {c}.',\n    lambda c: f'a pixelated photo of the {c}.',\n    lambda c: f'a sculpture of the {c}.',\n    lambda c: f'a bright photo of the {c}.',\n    lambda c: f'a cropped photo of a {c}.',\n    lambda c: f'a plastic {c}.',\n    lambda c: f'a photo of the dirty {c}.',\n    lambda c: f'a jpeg corrupted photo of a {c}.',\n    lambda c: f'a blurry photo of the {c}.',\n    lambda c: f'a photo of the {c}.',\n    lambda c: f'a good photo of the {c}.',\n    lambda c: f'a rendering of the {c}.',\n    lambda c: f'a {c} in a video game.',\n    lambda c: f'a photo of one {c}.',\n    lambda c: f'a doodle of a {c}.',\n    lambda c: f'a close-up photo of the {c}.',\n    lambda c: f'a photo of a {c}.',\n    lambda c: f'the origami {c}.',\n    lambda c: f'the {c} in a video game.',\n    lambda c: f'a sketch of a {c}.',\n    lambda c: f'a doodle of the {c}.',\n    lambda c: f'a origami {c}.',\n    lambda c: f'a low resolution photo of a {c}.',\n    lambda c: f'the toy {c}.',\n    lambda c: f'a rendition of the {c}.',\n    lambda c: f'a photo of the clean {c}.',\n    lambda c: f'a photo of a large {c}.',\n    lambda c: f'a rendition of a {c}.',\n    lambda c: f'a photo of a nice {c}.',\n    lambda c: f'a photo of a weird {c}.',\n    lambda c: f'a blurry photo of a {c}.',\n    lambda c: f'a cartoon {c}.',\n    lambda c: f'art of a {c}.',\n    lambda c: f'a sketch of the {c}.',\n    lambda c: f'a embroidered {c}.',\n    lambda c: f'a pixelated photo of a {c}.',\n    lambda c: f'itap of the {c}.',\n    lambda c: f'a jpeg corrupted photo of the {c}.',\n    lambda c: f'a good photo of a {c}.',\n    lambda c: f'a plushie {c}.',\n    lambda c: f'a photo of the nice {c}.',\n    lambda c: f'a photo of the small {c}.',\n    lambda c: f'a photo of the weird {c}.',\n    lambda c: f'the cartoon {c}.',\n    lambda c: f'art of the {c}.',\n    lambda c: f'a drawing of the {c}.',\n    lambda c: f'a photo of the large {c}.',\n    lambda c: f'a black and white photo of a {c}.',\n    lambda c: f'the plushie {c}.',\n    lambda c: f'a dark photo of a {c}.',\n    lambda c: f'itap of a {c}.',\n    lambda c: f'graffiti of the {c}.',\n    lambda c: f'a toy {c}.',\n    lambda c: f'itap of my {c}.',\n    lambda c: f'a photo of a cool {c}.',\n    lambda c: f'a photo of a small {c}.',\n    lambda c: f'a tattoo of the {c}.',\n)\n\n\n# a much smaller subset of above prompts\n# from 
https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb\nSIMPLE_IMAGENET_TEMPLATES = (\n    lambda c: f'itap of a {c}.',\n    lambda c: f'a bad photo of the {c}.',\n    lambda c: f'a origami {c}.',\n    lambda c: f'a photo of the large {c}.',\n    lambda c: f'a {c} in a video game.',\n    lambda c: f'art of the {c}.',\n    lambda c: f'a photo of the small {c}.',\n)\n\n\nIMAGENET_CLASSNAMES = (\n    \"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\",\n    \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\",\n    \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\",\n    \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\",\n    \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\",\n    \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\",\n    \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\",\n    \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\",\n    \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\",\n    \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\",\n    \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\",\n    \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\",\n    \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\",\n    \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\",\n    \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\",\n    \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\",\n    \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\",\n    \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\",\n    \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\",\n    \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\",\n    \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\",\n    \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\",\n    \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\",\n    \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\",\n    \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\",\n    \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\",\n    \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\",\n    \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\",\n    \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\",\n    \"English foxhound\", \"Redbone Coonhound\", 
\"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\",\n    \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\",\n    \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\",\n    \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\",\n    \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\",\n    \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\",\n    \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\",\n    \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\",\n    \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\",\n    \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\",\n    \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\",\n    \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\",\n    \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\",\n    \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\",\n    \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\",\n    \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\",\n    \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\",\n    \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\",\n    \"French Bulldog\", \"Great Dane\", \"St. 
Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\",\n    \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\",\n    \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\",\n    \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\",\n    \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\",\n    \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\",\n    \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\",\n    \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\",\n    \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\",\n    \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\",\n    \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\",\n    \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\",\n    \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\",\n    \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\",\n    \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\",\n    \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\",\n    \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\",\n    \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\",\n    \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\",\n    \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\",\n    \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\",\n    \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\",\n    \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\",\n    \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\",\n    \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\",\n    \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\",\n    \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\",\n    \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\",\n    \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\",\n    \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\",\n    \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\",\n    \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\",\n    \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\",\n    \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", 
\"bolo tie\", \"poke bonnet\",\n    \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\",\n    \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\",\n    \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\",\n    \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\",\n    \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\",\n    \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\",\n    \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\",\n    \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\",\n    \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\",\n    \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\",\n    \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\",\n    \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\",\n    \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\",\n    \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\",\n    \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\",\n    \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\",\n    \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\",\n    \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\",\n    \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\",\n    \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\",\n    \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\",\n    \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\",\n    \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\",\n    \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\",\n    \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\",\n    \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\",\n    \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\",\n    \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\",\n    \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\",\n    \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\",\n    \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\",\n    \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\",\n    \"mitten\", \"mixing bowl\", \"mobile 
home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\",\n    \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\",\n    \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\",\n    \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\",\n    \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\",\n    \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\",\n    \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\",\n    \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\",\n    \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\",\n    \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\",\n    \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\",\n    \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\",\n    \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\",\n    \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\",\n    \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\",\n    \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\",\n    \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\",\n    \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\",\n    \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\",\n    \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\",\n    \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\",\n    \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\",\n    \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\",\n    \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\",\n    \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\",\n    \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\",\n    \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\",\n    \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\",\n    \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\",\n    \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\",\n    \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\",\n    \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\",\n    \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\",\n    \"triumphal 
arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\",\n    \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\",\n    \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\",\n    \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\",\n    \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\",\n    \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\",\n    \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\",\n    \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\",\n    \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\",\n    \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\",\n    \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\",\n    \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\",\n    \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\",\n    \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\",\n    \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\",\n    \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\",\n    \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\",\n    \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\",\n    \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"\n)\n\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"pdm-backend\"]\nbuild-backend = \"pdm.backend\"\n\n[project]\nname = \"inf-cl\"\nversion = \"1.2\"\nauthors = [\n    {name = \"Zesen Cheng\", email = \"cyanlaser@stu.pku.edu.cn\"},\n    {name = \"Hang Zhang\"},\n    {name = \"Kehan Li\"},\n    {name = \"Xin Li\"},\n]\ndescription = \"A highly memory-efficient contrastive loss.\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nlicense = {text = \"MIT\"}\nclassifiers = [\n        'Development Status :: 4 - Beta',\n        'Intended Audience :: Education',\n        'Intended Audience :: Science/Research',\n        'License :: OSI Approved :: MIT License',\n        'Programming Language :: Python :: 3.8',\n        'Programming Language :: Python :: 3.9',\n        'Programming Language :: Python :: 3.10',\n        'Programming Language :: Python :: 3.11',\n        'Programming Language :: Python :: 3.12',\n        'Topic :: Scientific/Engineering',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n        'Topic :: Software Development',\n        'Topic :: Software Development :: Libraries',\n        'Topic :: Software Development :: Libraries :: Python Modules',\n]\ndependencies = [\n    'numpy',\n    'triton>=2.2.0',\n]\n\n[project.urls]\nHomepage = \"https://github.com/clownrat6/Inf-CLIP/inf_cl\"\nIssues = \"https://github.com/clownrat6/Inf-CLIP/issues\"\n\n[tool.pdm.build]\nexcludes = [\"./.git\"]\npackage-dir = \".\"\nincludes = [\"./inf_cl\"]\n"
  },
  {
    "path": "requirements.txt",
    "content": "--extra-index-url https://download.pytorch.org/whl/cu118\n# basic dependencies\ntorch==2.2.0\ntorchvision==0.17.0\nnumpy==1.24.4\ntimm\n# data processing\nwebdataset\npandas\nftfy\nregex\nbraceexpand\n# The newest pillow fix this bug: \"UserWarning: image file could not be identified because WEBP support not installed\"\npillow==10.4.0 # Refer to this issue: https://github.com/ContinuumIO/anaconda-issues/issues/10737\n# logging tools\ntensorboard\ntensorboardX"
  },
  {
    "path": "scripts/benchmarks_eval.sh",
    "content": "clip_benchmark eval \\\n    --model LiT-B-16 \\\n    --pretrained work_dirs/epoch_8.pt \\\n    --dataset datasets/imagenet.txt \\\n    --recall_k 1 5 10 \\\n    --dataset_root datasets/clip-benchmark/wds_{dataset_cleaned} \\\n    --output \"benchmark_{dataset}_{pretrained}_{model}_{language}_{task}.json\"\n"
  },
  {
    "path": "scripts/cc12m/clip_vit-b-32_bs32k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=32768\nLOCAL_BATCH_SIZE=512\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=40\nTRAIN_NUM_SAMPLES=10445970\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=clip_cc12m\nRUN_NAME=vit-b-32_bs32k_e40\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model ViT-B-32 \\\n    --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/IMAGE/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --lr 5e-4 \\\n    --beta1 0.9 \\\n    --beta2 0.98 \\\n    --eps 1.0e-8 \\\n    --wd 0.5 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 5 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\"
  },
  {
    "path": "scripts/cc12m/lit_vit-b-16_bs32k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=32768\nLOCAL_BATCH_SIZE=1024\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=20\nTRAIN_NUM_SAMPLES=10445970\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_cc12m\nRUN_NAME=lit-b-16_bs32k_e20\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-B-16 \\\n    --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/cc12m/lit_vit-b-32_bs32k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=32768\nLOCAL_BATCH_SIZE=1024\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=20\nTRAIN_NUM_SAMPLES=10445970\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_cc12m\nRUN_NAME=lit-b-32_bs32k_e20\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-B-32 \\\n    --train-data ${DATA_DIR}'/cc12m/{0000..1044}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/cc3m/clip_r50_bs4k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=4096\nLOCAL_BATCH_SIZE=256\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=40\nTRAIN_NUM_SAMPLES=3018714\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=clip_cc3m\nRUN_NAME=r50_bs4k_e40\nDATA_DIR=/mnt/damovl/MEDIA\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model RN50 \\\n    --train-data ${DATA_DIR}'/cc3m/{0000..0301}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --lr 5e-4 \\\n    --beta1 0.9 \\\n    --beta2 0.98 \\\n    --eps 1.0e-8 \\\n    --wd 0.5 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 5 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/cc3m/clip_vit-b-32_bs16k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=16384\nLOCAL_BATCH_SIZE=256\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=40\nTRAIN_NUM_SAMPLES=3018714\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=clip_cc3m\nRUN_NAME=vit-b-32_bs16k_e40\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model ViT-B-32 \\\n    --train-data ${DATA_DIR}'/cc3m/cc3m-train-{0000..0575}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --lr 5e-4 \\\n    --beta1 0.9 \\\n    --beta2 0.98 \\\n    --eps 1.0e-8 \\\n    --wd 0.5 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 5 \\\n    --log_dir $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/cc3m/lit_vit-b-32_bs16k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=16384\nLOCAL_BATCH_SIZE=256\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=20\nTRAIN_NUM_SAMPLES=3018714\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_cc3m\nRUN_NAME=lit-b-32_bs16k_e20\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-B-32 \\\n    --train-data ${DATA_DIR}'/cc3m/{0000..1044}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 32 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/imagenet_eval.sh",
    "content": "torchrun --nproc_per_node 1 \\\n    -m inf_cl_train.main \\\n    --imagenet-val datasets/imagenet-1k/val \\\n    --model ViT-B-16 \\\n    --pretrained openai \\\n    --workers 64 \\\n"
  },
  {
    "path": "scripts/laion400m/clip_vit-b-32_bs256k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=262144\nLOCAL_BATCH_SIZE=512\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=8\nTRAIN_NUM_SAMPLES=280321756\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=clip_laion400m\nRUN_NAME=vit-b-32_bs256k_e8\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model ViT-B-32 \\\n    --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --lr 5e-4 \\\n    --beta1 0.9 \\\n    --beta2 0.98 \\\n    --eps 1.0e-8 \\\n    --wd 0.5 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n"
  },
  {
    "path": "scripts/laion400m/lit_vit-b-16_bs256k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=262144\nLOCAL_BATCH_SIZE=512\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=8\nTRAIN_NUM_SAMPLES=280321756\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_laion400m\nRUN_NAME=lit-b-16_bs256k_e8\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-B-16 \\\n    --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/laion400m/lit_vit-b-32_bs256k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=262144\nLOCAL_BATCH_SIZE=512\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=8\nTRAIN_NUM_SAMPLES=280321756\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_laion400m\nRUN_NAME=lit-b-32_bs256k_e8\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-B-32 \\\n    --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
  {
    "path": "scripts/laion400m/lit_vit-l-16_bs256k.sh",
    "content": "# Environment Variables\nARG_WORLD_SIZE=${1:-1}\nARG_NPROC_PER_NODE=${2:-8}\nARG_MASTER_ADDR=\"127.0.0.1\"\nARG_MASTER_PORT=16666\nARG_RANK=${3:-0}\n\n# Multiple conditions\nif [ ! -n \"$WORLD_SIZE\" ] || [ ! -n \"$NPROC_PER_NODE\" ]; then\n    WORLD_SIZE=$ARG_WORLD_SIZE\n    NPROC_PER_NODE=$ARG_NPROC_PER_NODE\nfi\nif [ ! -n \"$MASTER_ADDR\" ] || [ ! -n \"$MASTER_PORT\" ] || [ ! -n \"$RANK\" ]; then\n    MASTER_ADDR=$ARG_MASTER_ADDR\n    MASTER_PORT=$ARG_MASTER_PORT\n    RANK=$ARG_RANK\nfi\n\necho \"WORLD_SIZE: $WORLD_SIZE\"\necho \"NPROC_PER_NODE: $NPROC_PER_NODE\"\n\n# Training Arguments\nGLOBAL_BATCH_SIZE=262144\nLOCAL_BATCH_SIZE=512\nACCUMULATION_STEPS=$[$GLOBAL_BATCH_SIZE/($WORLD_SIZE*$NPROC_PER_NODE*$LOCAL_BATCH_SIZE)]\nEPOCHS=8\nTRAIN_NUM_SAMPLES=280321756\nWARMUP_STEPS=$[$TRAIN_NUM_SAMPLES/(2*$GLOBAL_BATCH_SIZE)]\necho \"ACCUMULATION_STEPS: $ACCUMULATION_STEPS\"\n\n# Log Arguments\nexport TRANSFORMERS_OFFLINE=1\nexport WANDB_PROJECT=lit_laion400m\nRUN_NAME=lit-l-16_bs256k_e8\nDATA_DIR=datasets\nOUTP_DIR=work_dirs\n\n\ntorchrun --nnodes $WORLD_SIZE \\\n    --nproc_per_node $NPROC_PER_NODE \\\n    --master_addr=$MASTER_ADDR \\\n    --master_port=$MASTER_PORT \\\n    --node_rank $RANK \\\n    -m inf_clip.train.main \\\n    --model LiT-L-16 \\\n    --train-data ${DATA_DIR}'/laion400m/{00000..41407}.tar' \\\n    --train-num-samples $TRAIN_NUM_SAMPLES \\\n    --aug-cfg scale='(0.08, 1.0)'\\\n    --dataset-type webdataset \\\n    --imagenet-val ${DATA_DIR}/imagenet-1k/val \\\n    --epochs $EPOCHS \\\n    --warmup $WARMUP_STEPS \\\n    --batch-size $LOCAL_BATCH_SIZE \\\n    --accum-freq $ACCUMULATION_STEPS \\\n    --optim adafactor \\\n    --lr 1e-3 \\\n    --beta1 0.9 \\\n    --beta2 0.95 \\\n    --eps 1.0e-8 \\\n    --wd 1e-4 \\\n    --grad-clip-norm 1.0 \\\n    --workers 16 \\\n    --precision amp \\\n    --infloss \\\n    --log-every-n-steps 1 \\\n    --logs $OUTP_DIR/$WANDB_PROJECT \\\n    --name $RUN_NAME \\\n    --save-frequency 1 \\\n    --zeroshot-frequency 1 \\\n    --report-to tensorboard \\\n    --resume latest \\\n"
  },
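  {
    "path": "scripts/run_example.sh",
    "content": "# Illustrative launcher for tests/example.py (the GPU count below is an assumption; adjust to your machine).\n# torchrun sets the env:// rendezvous variables that dist.init_process_group relies on in the example.\ntorchrun --nproc_per_node 8 tests/example.py\n"
  },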
  {
    "path": "tests/example.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport numpy as np\n\nfrom inf_cl import cal_inf_loss\n\n\ndef create_cl_tensors(rank, world_size):\n    # Parameters\n    dtype = torch.float32\n    num_heads = 3        # Number of attention heads\n    seq_length_q = 32768 # Sequence length\n    seq_length_k = 32768\n    d_model = 256        # Dimension of each head (must be 16, 32, 64, or 128)\n\n    # Randomly initialize inputs\n    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda:{rank}\")\n    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f\"cuda:{rank}\")\n    l = torch.ones([], dtype=dtype, device=f\"cuda:{rank}\") * np.log(1 / 0.07)\n\n    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query\n    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key\n    l = l.requires_grad_() # Logit scale\n\n    return q, k, l\n\n\nif __name__ == \"__main__\":\n    # Assume that the distributed environment has been initialized\n    dist.init_process_group(\"nccl\")\n\n    rank = dist.get_rank()\n    world_size = dist.get_world_size()\n\n    torch.cuda.set_device(rank)\n\n    # Exampled by Image-Text Contrastive Learning, q is the global image features, \n    # k is the text features, and l is the logit scale.\n    q, k, l = create_cl_tensors(rank, world_size)\n\n    # labels are diagonal elements by default. \n    # labels = torch.arange(q.shape[0])\n    loss = cal_inf_loss(q, k, scale=l.exp())\n\n    print(loss)\n"
  }
]