Repository: ByChelsea/VAND-APRIL-GAN
Branch: master
Commit: f13b8a634e04
Files: 90
Total size: 24.2 MB

Directory structure:
gitextract_7tq5j4qy/

├── LICENSE
├── README.md
├── data/
│   ├── mvtec.py
│   └── visa.py
├── dataset.py
├── exps/
│   └── pretrained/
│       ├── mvtec_pretrained.pth
│       └── visa_pretrained.pth
├── few_shot.py
├── loss.py
├── model.py
├── open_clip/
│   ├── __init__.py
│   ├── coca_model.py
│   ├── constants.py
│   ├── factory.py
│   ├── generation_utils.py
│   ├── hf_configs.py
│   ├── hf_model.py
│   ├── loss.py
│   ├── model.py
│   ├── model_configs/
│   │   ├── RN101-quickgelu.json
│   │   ├── RN101.json
│   │   ├── RN50-quickgelu.json
│   │   ├── RN50.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-bigG-14.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── coca_ViT-B-32.json
│   │   ├── coca_ViT-L-14.json
│   │   ├── coca_base.json
│   │   ├── coca_roberta-ViT-B-32.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   ├── convnext_xxlarge_320.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   └── xlm-roberta-large-ViT-H-14.json
│   ├── modified_resnet.py
│   ├── openai.py
│   ├── pretrained.py
│   ├── push_to_hf_hub.py
│   ├── timm_model.py
│   ├── tokenizer.py
│   ├── transform.py
│   ├── transformer.py
│   ├── utils.py
│   └── version.py
├── prompt_ensemble.py
├── requirements.txt
├── test.py
├── test_few_shot.sh
├── test_zero_shot.sh
├── train.py
└── train.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Xuhai Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
[Workshop Link](https://sites.google.com/view/vand-cvpr23/home) | [Challenge Link](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0) | [Report Paper](https://arxiv.org/abs/2305.17382)
---
[Xuhai Chen](https://scholar.google.com.hk/citations?user=LU4etJ0AAAAJ&hl=zh-CN&authuser=1) · [Yue Han](https://scholar.google.com/citations?hl=en&user=08E500gAAAAJ&view_op=list_works&gmla=AHoSzlVzTXnclaPp9h1g8xAZQBrsxdFXvhunMA3AmRm_GSLnZA1956xavx6hmPaCFCysonsXeTQyhB_cokdUFacUc5HBunMPW-uOtLZLTTufiZiHB6hAVgr9l7cJ_UHKeQ) · [Jiangning Zhang](https://zhangzjn.github.io/)

This repository contains the official PyTorch implementation of the [Zero-/Few-shot Anomaly Classification and Segmentation Method](https://arxiv.org/abs/2305.17382) used in the [CVPR 2023 VAND Challenge](https://sites.google.com/view/vand-cvpr23/challenge?authuser=0), which can be viewed as an improved version of [WinCLIP](https://arxiv.org/abs/2303.14814). We won **first place** in the Zero-shot Track and received an **Honorable Mention** in the Few-shot Track.

<img src="illustration/main.png" alt="Model Structure" style="max-width: 50px; height: auto;">

**Results on the Challenge official test set**

<img src="illustration/results.png" alt="Model Structure" style="max-width: 50px; height: auto;">

## Installation

- Prepare experimental environments

  ```shell
  pip install -r requirements.txt
  ```
  
## Dataset Preparation 
### MVTec AD
- Download and extract [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) into `data/mvtec`
- run `python data/mvtec.py` to obtain `data/mvtec/meta.json`
```
data
├── mvtec
    ├── meta.json
    ├── bottle
        ├── train
            ├── good
                ├── 000.png
        ├── test
            ├── good
                ├── 000.png
            ├── anomaly1
                ├── 000.png
        ├── ground_truth
            ├── anomaly1
                ├── 000.png
```
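
Each record in the generated `meta.json` carries the fields written by `data/mvtec.py` (`img_path`, `mask_path`, `cls_name`, `specie_name`, `anomaly`), grouped by split and class. An illustrative entry (the concrete paths are examples only):
```json
{
    "test": {
        "bottle": [
            {
                "img_path": "bottle/test/broken_large/000.png",
                "mask_path": "bottle/ground_truth/broken_large/000_mask.png",
                "cls_name": "bottle",
                "specie_name": "broken_large",
                "anomaly": 1
            }
        ]
    }
}
```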

### VisA
- Download and extract [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar) into `data/visa`
- run `python data/visa.py` to obtain `data/visa/meta.json`
```
data
├── visa
    ├── meta.json
    ├── candle
        ├── Data
            ├── Images
                ├── Anomaly
                    ├── 000.JPG
                ├── Normal
                    ├── 0000.JPG
            ├── Masks
                ├── Anomaly
                    ├── 000.png
```

## Train
Set parameters in `train.sh`.
- `train_data_path`: the path to the training dataset
- `dataset`: name of the training dataset, one of `mvtec`, `visa`
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images fed into the CLIP model
- `aug_rate`: the probability of stitching images (only for mvtec)

Then run the following command:
  ```shell
  sh train.sh
  ```
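
For orientation, `train.sh` is a thin wrapper around `train.py`; a minimal sketch using the parameters above might look like the following (the flag values are illustrative placeholders, not the committed defaults):

```shell
python train.py --dataset mvtec --train_data_path ./data/mvtec \
    --model ViT-L-14-336 --pretrained openai \
    --features_list 6 12 18 24 --image_size 518 --aug_rate 0.2
```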

## Test
### Pretrained Models
We provide our pre-trained models in `exps/pretrained`, where `mvtec_pretrained.pth` represents the model trained on the MVTec AD dataset and `visa_pretrained.pth` represents the model trained on the VisA dataset.


### Zero-shot Setting
Set parameters in `test_zero_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `image_size`: the size of the images fed into the CLIP model
- `mode`: zero shot or few shot

Then run the following command:
  ```shell
  sh test_zero_shot.sh
  ```
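
A minimal zero-shot sketch built from the parameters above (values again illustrative; the checkpoint would typically be the model trained on the *other* dataset):

```shell
python test.py --mode zero_shot --dataset mvtec --data_path ./data/mvtec \
    --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
    --model ViT-L-14-336 --pretrained openai \
    --features_list 6 12 18 24 --image_size 518
```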

### Few-shot Setting
Set parameters in `test_few_shot.sh`.
- `data_path`: the path to the test dataset
- `dataset`: name of the test dataset, one of `mvtec`, `visa`
- `checkpoint_path`: the path to the test model
- `model`: the CLIP model
- `pretrained`: the pretrained weights
- `features_list`: features to be mapped into the joint embedding space
- `few_shot_features`: features stored in the memory banks
- `image_size`: the size of the images fed into the CLIP model
- `mode`: zero shot or few shot
- `k_shot`: the number of reference images
- `seed`: the random seed

Then run the following command:
  ```shell
  sh test_few_shot.sh
  ```
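
A corresponding few-shot sketch (values illustrative, as above):

```shell
python test.py --mode few_shot --dataset mvtec --data_path ./data/mvtec \
    --checkpoint_path ./exps/pretrained/visa_pretrained.pth \
    --model ViT-L-14-336 --pretrained openai \
    --features_list 6 12 18 24 --few_shot_features 6 12 18 24 \
    --image_size 518 --k_shot 4 --seed 10
```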

## Citation
If our work is helpful for your research, please consider citing:

```
@article{chen2023zero,
  title={A Zero-/Few-Shot Anomaly Classification and Segmentation Method for CVPR 2023 VAND Workshop Challenge Tracks 1\&2: 1st Place on Zero-shot AD and 4th Place on Few-shot AD},
  author={Chen, Xuhai and Han, Yue and Zhang, Jiangning},
  journal={arXiv preprint arXiv:2305.17382},
  year={2023}
}
```

## Acknowledgements
We thank [WinCLIP: Zero-/Few-Shot Anomaly Classification and Segmentation](https://arxiv.org/abs/2303.14814) for the groundwork it provided for our research.



================================================
FILE: data/mvtec.py
================================================
import os
import json


class MVTecSolver(object):
    CLSNAMES = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid',
        'hazelnut', 'leather', 'metal_nut', 'pill', 'screw',
        'tile', 'toothbrush', 'transistor', 'wood', 'zipper',
    ]

    def __init__(self, root='data/mvtec'):
        self.root = root
        self.meta_path = f'{root}/meta.json'

    def run(self):
        info = dict(train={}, test={})
        for cls_name in self.CLSNAMES:
            cls_dir = f'{self.root}/{cls_name}'
            for phase in ['train', 'test']:
                cls_info = []
                species = os.listdir(f'{cls_dir}/{phase}')
                for specie in species:
                    is_abnormal = specie != 'good'
                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
                    mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
                    img_names.sort()
                    if mask_names is not None:
                        mask_names.sort()
                    for idx, img_name in enumerate(img_names):
                        info_img = dict(
                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
                            mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
                            cls_name=cls_name,
                            specie_name=specie,
                            anomaly=1 if is_abnormal else 0,
                        )
                        cls_info.append(info_img)
                info[phase][cls_name] = cls_info
        with open(self.meta_path, 'w') as f:
            f.write(json.dumps(info, indent=4) + "\n")

if __name__ == '__main__':
    runner = MVTecSolver(root='data/mvtec')
    runner.run()


================================================
FILE: data/visa.py
================================================
import os
import json
import pandas as pd


class VisASolver(object):
    CLSNAMES = [
        'candle', 'capsules', 'cashew', 'chewinggum', 'fryum',
        'macaroni1', 'macaroni2', 'pcb1', 'pcb2', 'pcb3',
        'pcb4', 'pipe_fryum',
    ]

    def __init__(self, root='data/visa'):
        self.root = root
        self.meta_path = f'{root}/meta.json'
        self.phases = ['train', 'test']
        self.csv_data = pd.read_csv(f'{root}/split_csv/1cls.csv', header=0)

    def run(self):
        columns = self.csv_data.columns  # [object, split, label, image, mask]
        info = {phase: {} for phase in self.phases}
        for cls_name in self.CLSNAMES:
            cls_data = self.csv_data[self.csv_data[columns[0]] == cls_name]
            for phase in self.phases:
                cls_info = []
                cls_data_phase = cls_data[cls_data[columns[1]] == phase]
                cls_data_phase.index = list(range(len(cls_data_phase)))
                for idx in range(cls_data_phase.shape[0]):
                    data = cls_data_phase.loc[idx]
                    is_abnormal = data[2] == 'anomaly'
                    info_img = dict(
                        img_path=data[3],
                        mask_path=data[4] if is_abnormal else '',
                        cls_name=cls_name,
                        specie_name='',
                        anomaly=1 if is_abnormal else 0,
                    )
                    cls_info.append(info_img)
                info[phase][cls_name] = cls_info
        with open(self.meta_path, 'w') as f:
            f.write(json.dumps(info, indent=4) + "\n")


if __name__ == '__main__':
    runner = VisASolver(root='data/visa')
    runner.run()


================================================
FILE: dataset.py
================================================
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os


class VisaDataset(data.Dataset):
	def __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None):
		self.root = root
		self.transform = transform
		self.target_transform = target_transform

		self.data_all = []
		meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
		name = self.root.split('/')[-1]
		meta_info = meta_info[mode]

		if mode == 'train':
			self.cls_names = [obj_name]
			save_dir = os.path.join(save_dir, 'k_shot.txt')
		else:
			self.cls_names = list(meta_info.keys())
		for cls_name in self.cls_names:
			if mode == 'train':
				data_tmp = meta_info[cls_name]
				indices = torch.randint(0, len(data_tmp), (k_shot,))
				for i in range(len(indices)):
					self.data_all.append(data_tmp[indices[i]])
					with open(save_dir, "a") as f:
						f.write(data_tmp[indices[i]]['img_path'] + '\n')
			else:
				self.data_all.extend(meta_info[cls_name])
		self.length = len(self.data_all)

	def __len__(self):
		return self.length

	def get_cls_names(self):
		return self.cls_names

	def __getitem__(self, index):
		data = self.data_all[index]
		img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
															  data['specie_name'], data['anomaly']
		img = Image.open(os.path.join(self.root, img_path))
		if anomaly == 0:
			img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # PIL size is (W, H); numpy expects (H, W)
		else:
			img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
			img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
		img = self.transform(img) if self.transform is not None else img
		img_mask = self.target_transform(
			img_mask) if self.target_transform is not None and img_mask is not None else img_mask
		img_mask = [] if img_mask is None else img_mask

		return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
				'img_path': os.path.join(self.root, img_path)}


class MVTecDataset(data.Dataset):
	def __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None, obj_name=None):
		self.root = root
		self.transform = transform
		self.target_transform = target_transform
		self.aug_rate = aug_rate

		self.data_all = []
		meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
		name = self.root.split('/')[-1]
		meta_info = meta_info[mode]

		if mode == 'train':
			self.cls_names = [obj_name]
			save_dir = os.path.join(save_dir, 'k_shot.txt')
		else:
			self.cls_names = list(meta_info.keys())
		for cls_name in self.cls_names:
			if mode == 'train':
				data_tmp = meta_info[cls_name]
				indices = torch.randint(0, len(data_tmp), (k_shot,))
				for i in range(len(indices)):
					self.data_all.append(data_tmp[indices[i]])
					with open(save_dir, "a") as f:
						f.write(data_tmp[indices[i]]['img_path'] + '\n')
			else:
				self.data_all.extend(meta_info[cls_name])
		self.length = len(self.data_all)

	def __len__(self):
		return self.length

	def get_cls_names(self):
		return self.cls_names

	def combine_img(self, cls_name):
		img_paths = os.path.join(self.root, cls_name, 'test')
		img_ls = []
		mask_ls = []
		for i in range(4):
			defect = os.listdir(img_paths)
			random_defect = random.choice(defect)
			files = os.listdir(os.path.join(img_paths, random_defect))
			random_file = random.choice(files)
			img_path = os.path.join(img_paths, random_defect, random_file)
			mask_path = os.path.join(self.root, cls_name, 'ground_truth', random_defect, random_file[:3] + '_mask.png')
			img = Image.open(img_path)
			img_ls.append(img)
			if random_defect == 'good':
				img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # PIL size is (W, H); numpy expects (H, W)
			else:
				img_mask = np.array(Image.open(mask_path).convert('L')) > 0
				img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
			mask_ls.append(img_mask)
		# image
		image_width, image_height = img_ls[0].size
		result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
		for i, img in enumerate(img_ls):
			row = i // 2
			col = i % 2
			x = col * image_width
			y = row * image_height
			result_image.paste(img, (x, y))

		# mask
		result_mask = Image.new("L", (2 * image_width, 2 * image_height))
		for i, img in enumerate(mask_ls):
			row = i // 2
			col = i % 2
			x = col * image_width
			y = row * image_height
			result_mask.paste(img, (x, y))

		return result_image, result_mask

	def __getitem__(self, index):
		data = self.data_all[index]
		img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
															  data['specie_name'], data['anomaly']
		random_number = random.random()
		if random_number < self.aug_rate:
			img, img_mask = self.combine_img(cls_name)
		else:
			img = Image.open(os.path.join(self.root, img_path))
			if anomaly == 0:
				img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')  # PIL size is (W, H); numpy expects (H, W)
			else:
				img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
				img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
		# transforms
		img = self.transform(img) if self.transform is not None else img
		img_mask = self.target_transform(
			img_mask) if self.target_transform is not None and img_mask is not None else img_mask
		img_mask = [] if img_mask is None else img_mask
		return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
				'img_path': os.path.join(self.root, img_path)}
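

# Usage sketch (illustrative): loading the MVTec test split. The torchvision
# transforms here are placeholders; train.py and test.py build theirs from the
# CLIP preprocessing configuration instead.
#
#   from torchvision import transforms
#   size = 518  # assumed input resolution
#   preprocess = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
#   target_transform = transforms.Compose([transforms.Resize((size, size)), transforms.ToTensor()])
#   test_data = MVTecDataset(root='data/mvtec', transform=preprocess,
#                            target_transform=target_transform, aug_rate=0.0)
#   loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)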


================================================
FILE: exps/pretrained/mvtec_pretrained.pth
================================================
[File too large to display: 12.0 MB]

================================================
FILE: exps/pretrained/visa_pretrained.pth
================================================
[File too large to display: 12.0 MB]

================================================
FILE: few_shot.py
================================================
import torch
from dataset import VisaDataset, MVTecDataset

def memory(model_name, model, obj_list, dataset_dir, save_path, preprocess, transform, k_shot,
           few_shot_features, dataset_name, device):
    mem_features = {}
    for obj in obj_list:
        if dataset_name == 'mvtec':
            data = MVTecDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
                                aug_rate=-1, mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)
        else:
            data = VisaDataset(root=dataset_dir, transform=preprocess, target_transform=transform,
                               mode='train', k_shot=k_shot, save_dir=save_path, obj_name=obj)
        dataloader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
        features = []
        for items in dataloader:
            image = items['img'].to(device)
            with torch.no_grad():
                image_features, patch_tokens = model.encode_image(image, few_shot_features)
                if 'ViT' in model_name:
                    patch_tokens = [p[0, 1:, :] for p in patch_tokens]
                else:
                    patch_tokens = [p[0].view(p.shape[1], -1).permute(1, 0).contiguous() for p in patch_tokens]
                features.append(patch_tokens)
        mem_features[obj] = [torch.cat(
            [features[j][i] for j in range(len(features))], dim=0) for i in range(len(features[0]))]
    return mem_features
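
# Usage sketch (illustrative): building the k-shot memory banks before few-shot
# testing. `model` and `preprocess` would come from open_clip's
# create_model_and_transforms; all concrete values below are placeholders.
#
#   mem_features = memory('ViT-L-14-336', model, obj_list, 'data/mvtec', './exps/tmp',
#                         preprocess, transform, k_shot=4,
#                         few_shot_features=[6, 12, 18, 24],
#                         dataset_name='mvtec', device='cuda')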

================================================
FILE: loss.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import exp

class FocalLoss(nn.Module):
    """
    copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss with smooth label cross entropy support, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param alpha: (tensor) 3D or 4D, the scalar factor for this criterion
    :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
                    putting more focus on hard, misclassified examples
    :param smooth: (float,double) smoothing value for the cross entropy
    :param balance_index: (int) balance class index, should be specified when alpha is a float
    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha

        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha

        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        return loss


class BinaryDiceLoss(nn.Module):
    def __init__(self):
        super(BinaryDiceLoss, self).__init__()

    def forward(self, input, targets):
        # batch size N
        N = targets.size()[0]
        # smoothing term to avoid division by zero
        smooth = 1
        # flatten the spatial dimensions
        input_flat = input.view(N, -1)
        targets_flat = targets.view(N, -1)

        # compute the intersection
        intersection = input_flat * targets_flat
        N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
        # average the Dice loss over the images in the batch
        loss = 1 - N_dice_eff.sum() / N
        return loss
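

# Usage sketch (illustrative): the two criteria are typically combined on the
# predicted anomaly maps. The shapes and weighting below are assumptions, not
# the exact recipe from train.py.
#
#   focal = FocalLoss()
#   dice = BinaryDiceLoss()
#   # scores: (N, 2, H, W) softmax score maps; masks: (N, 1, H, W) binary ground truth
#   loss = focal(scores, masks) + dice(scores[:, 1, :, :], masks.squeeze(1))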


================================================
FILE: model.py
================================================
from torch import Tensor, nn
import torch
from torch.nn import functional as F

class LinearLayer(nn.Module):
    def __init__(self, dim_in, dim_out, k, model):
        super(LinearLayer, self).__init__()
        if 'ViT' in model:
            self.fc = nn.ModuleList([nn.Linear(dim_in, dim_out) for i in range(k)])
        else:
            self.fc = nn.ModuleList([nn.Linear(dim_in * 2 ** (i + 2), dim_out) for i in range(k)])

    def forward(self, tokens):
        for i in range(len(tokens)):
            if len(tokens[i].shape) == 3:
                tokens[i] = self.fc[i](tokens[i][:, 1:, :])
            else:
                B, C, H, W = tokens[i].shape
                tokens[i] = self.fc[i](tokens[i].view(B, C, -1).permute(0, 2, 1).contiguous())
        return tokens
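

# Usage sketch (illustrative): projecting tapped ViT patch tokens into the joint
# embedding space. The dimensions are assumptions for a ViT-L/14 backbone with
# four tapped layers and a 768-d joint embedding space.
#
#   trainable_layer = LinearLayer(dim_in=1024, dim_out=768, k=4, model='ViT-L-14-336')
#   # patch_tokens: list of 4 tensors of shape (B, 1 + L, 1024) from encode_image
#   projected = trainable_layer(patch_tokens)  # -> list of (B, L, 768) tensors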


================================================
FILE: open_clip/__init__.py
================================================
from .coca_model import CoCa
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \
    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \
    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub
from .tokenizer import SimpleTokenizer, tokenize, decode
from .transform import image_transform, AugmentationCfg


================================================
FILE: open_clip/coca_model.py
================================================
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from dataclasses import dataclass

from .transformer import (
    LayerNormFp32,
    LayerNorm,
    QuickGELU,
    MultimodalTransformer,
)
from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower

try:
    from transformers import (
        BeamSearchScorer,
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
        MaxLengthCriteria,
        StoppingCriteriaList
    )

    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False


@dataclass
class MultimodalCfg(CLIPTextCfg):
    mlp_ratio: int = 4
    dim_head: int = 64
    heads: int = 8
    n_queries: int = 256
    attn_pooler_heads: int = 8


def _build_text_decoder_tower(
        embed_dim,
        multimodal_cfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
    act_layer = QuickGELU if quick_gelu else nn.GELU
    norm_layer = (
        LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    )

    decoder = MultimodalTransformer(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )

    return decoder


class CoCa(nn.Module):
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
            text_cfg: CLIPTextCfg,
            vision_cfg: CLIPVisionCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            pad_id: int = 0,
    ):
        super().__init__()
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            text_cfg.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        self.visual = _build_vision_tower(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.pad_id = pad_id

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    # def _encode_image(self, images, out_layers, normalize=True):
    #     image_latent, tokens_embs = self.visual(images, out_layers)
    #     image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent
    #     return image_latent, tokens_embs
    def _encode_image(self, images, out_layers, normalize=True):
        image_latent = self.visual(images, out_layers)
        return image_latent

    def _encode_text(self, text, normalize=True, embed_cls=True):
        text = text[:, :-1] if embed_cls else text # make space for CLS token
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    # def encode_image(self, images, out_layers, normalize=True):
    #     image_latent, _ = self._encode_image(images, out_layers, normalize=normalize)
    #     return image_latent
    def encode_image(self, images, out_layers, normalize=True):
        image_latent = self._encode_image(images, out_layers, normalize=normalize)
        return image_latent

    def encode_text(self, text, normalize=True, embed_cls=True):
        text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls)
        return text_latent

    def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None):
        text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
        if image_latent is None or image_embs is None:
            image_latent, image_embs = self._encode_image(image)

        # TODO: add assertion to avoid bugs?
        labels = text[:, -token_embs.shape[1]:]

        logits = self.text_decoder(image_embs, token_embs)
        return {
            "image_features": image_latent,
            "text_features": text_latent,
            "logits": logits,
            "labels": labels,
            "logit_scale": self.logit_scale.exp()
        }

    def generate(
        self,
        image,
        text=None,
        seq_len=30,
        max_seq_len=77,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,  # keep tokens in the 1 - top_p quantile
        top_k=1,  # keeps the top_k most probable tokens
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False # if True output.shape == (batch_size, seq_len)
    ):
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"

        with torch.no_grad():
            sot_token_id = 49406 if sot_token_id is None else sot_token_id
            eos_token_id = 49407 if eos_token_id is None else eos_token_id
            pad_token_id = self.pad_id if pad_token_id is None else pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]

            stopping_criteria = StoppingCriteriaList(
                stopping_criteria
            )

            device = image.device

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    image_inputs = image,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                )
                if fixed_output_length and output.shape[1] < seq_len:
                    return torch.cat(
                        (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            image_latent, image_embs = self._encode_image(image)

            if text is None:
                text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            if num_dims == 1:
                text = text[None, :]

            cur_len = text.shape[1]
            self.eval()
            out = text

            while True:
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1]
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    if not fixed_output_length:
                        break
                else:
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if stopping_criteria(out, None):
                    break

            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
            self,
            image_inputs,
            pad_token_id=None,
            eos_token_id=None,
            sot_token_id=None,
            num_beams=6,
            num_beam_groups=3,
            min_seq_len=5,
            stopping_criteria=None,
            logit_processor=None,
            logit_warper=None,
    ):
        device = image_inputs.device
        batch_size = image_inputs.shape[0]
        image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)
        image_latent, image_embs = self._encode_image(image_inputs)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
        # the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # predicted tokens in cur_len step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # indices which will form the beams in the next time step
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)
            outputs = self(
                model_inputs['images'],
                model_inputs['text'],
                embed_cls=False,
                image_latent=image_latent,
                image_embs=image_embs
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # indices of beams of current group among all sentences in batch
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # stateless
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # (beam_idx // group_size) -> batch_idx
                # (beam_idx % group_size) -> offset of idx inside the group
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            # increase cur_len
            cur_len = cur_len + 1
            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                break

        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']


def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    else:
        position_ids = None
    return {
        "text": input_ids,
        "images": image_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }


================================================
FILE: open_clip/constants.py
================================================
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


================================================
FILE: open_clip/factory.py
================================================
import json
import logging
import os
import pathlib
import re
import numpy as np
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize


HF_HUB_PREFIX = 'hf-hub:'
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def _rescan_model_configs():
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        with open(cf, 'r') as f:
            model_cfg = json.load(f)
            if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
                _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


_rescan_model_configs()  # initial populate of model config registry


def list_models():
    """ enumerate available model architectures based on config files """
    return list(_MODEL_CONFIGS.keys())


def add_model_config(path):
    """ add model config path or file and update registry """
    if not isinstance(path, Path):
        path = Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()


def get_model_config(model_name):
    if model_name in _MODEL_CONFIGS:
        return deepcopy(_MODEL_CONFIGS[model_name])
    else:
        return None


def get_tokenizer(model_name):
    if model_name.startswith(HF_HUB_PREFIX):
        tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):])
    else:
        config = get_model_config(model_name)
        tokenizer = HFTokenizer(
            config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
    return tokenizer


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def create_model(
        model_name: str,
        img_size: int,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
):
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        pretrained_cfg = config['preprocess_cfg']
        model_cfg = config['model_cfg']
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        pretrained_cfg = {}
        model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg['vision_cfg']['image_size'] != img_size:
            model_cfg['vision_cfg']['image_size'] = img_size
            cast_dtype = get_cast_dtype(precision)

            model_pre = load_openai_model(
                model_name,
                precision=precision,
                device=device,
                jit=jit,
                cache_dir=cache_dir,
            )
            state_dict = model_pre.state_dict()

            # to always output dict even if it is clip
            if output_dict and hasattr(model_pre, "output_dict"):
                model_pre.output_dict = True

            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
            ### for resnet
            if not hasattr(model.visual, 'grid_size'):
                model.visual.grid_size = int(np.sqrt(model.visual.attnpool.positional_embedding.shape[0] - 1))
            resize_pos_embed(state_dict, model)
            incompatible_keys = model.load_state_dict(state_dict, strict=True)
            model.to(device=device)
            if precision in ("fp16", "bf16"):
                convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

            # set image / mean metadata from pretrained_cfg if available, or use default
            model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
            model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

            # to always output dict even if it is clip
            if output_dict and hasattr(model, "output_dict"):
                model.output_dict = True

            if jit:
                model = torch.jit.script(model)
        else:
            model = load_openai_model(
                model_name,
                precision=precision,
                device=device,
                jit=jit,
                cache_dir=cache_dir,
            )

            # to always output dict even if it is clip
            if output_dict and hasattr(model, "output_dict"):
                model.output_dict = True
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logging.info(f'Loaded {model_name} model config.')
        else:
            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')
        model_cfg['vision_cfg']['image_size'] = img_size

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        if pretrained_image:
            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        if custom_text:
            if is_hf_model:
                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
            if "coca" in model_name:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            if checkpoint_path:
                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                error_str = (
                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
                    f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
                logging.warning(error_str)
                raise RuntimeError(error_str)
            pretrained_loaded = True
        elif has_hf_hub_prefix:
            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # set image / mean metadata from pretrained_cfg if available, or use default
        model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
        model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True

        if jit:
            model = torch.jit.script(model)

    return model


def create_loss(args):
    if args.distill:
        return DistillClipLoss(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    return ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )


def create_model_and_transforms(
        model_name: str,
        img_size: int,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
):
    model = create_model(
        model_name,
        img_size,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
    )

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess_train = image_transform(
        model.visual.image_size,
        is_train=True,
        mean=image_mean,
        std=image_std,
        aug_cfg=aug_cfg,
    )
    preprocess_val = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess_train, preprocess_val


def create_model_from_pretrained(
        model_name: str,
        img_size: int,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        return_transform: bool = True,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        cache_dir: Optional[str] = None,
):
    model = create_model(
        model_name,
        img_size,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        cache_dir=cache_dir,
        require_pretrained=True,
    )

    if not return_transform:
        return model

    image_mean = image_mean or getattr(model.visual, 'image_mean', None)
    image_std = image_std or getattr(model.visual, 'image_std', None)
    preprocess = image_transform(
        model.visual.image_size,
        is_train=False,
        mean=image_mean,
        std=image_std,
    )

    return model, preprocess
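

# Minimal usage sketch of the factory helpers above; 'ViT-B-16' is an
# illustrative model name and img_size=224 an assumed input size (pass a
# pretrained tag such as 'openai' to also download weights).
if __name__ == "__main__":
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        'ViT-B-16',
        img_size=224,
        pretrained=None,
        device='cpu',
    )
    print(type(model).__name__, model.visual.image_size)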


================================================
FILE: open_clip/generation_utils.py
================================================


================================================
FILE: open_clip/hf_configs.py
================================================
# HF architecture dict:
arch_dict = {
    # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
    "roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
    "xlm-roberta": {
        "config_names": {
            "context_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "width": "hidden_size",
            "heads": "num_attention_heads",
            "layers": "num_hidden_layers",
            "layer_attr": "layer",
            "token_embeddings_attr": "embeddings"
        },
        "pooler": "mean_pooler",
    },
    # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
    "mt5": {
        "config_names": {
            # unlimited seqlen
            # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
            # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
            "context_length": "",
            "vocab_size": "vocab_size",
            "width": "d_model",
            "heads": "num_heads",
            "layers": "num_layers",
            "layer_attr": "block",
            "token_embeddings_attr": "embed_tokens"
        },
        "pooler": "mean_pooler",
    },
}
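

# Usage sketch: arch_dict maps a HuggingFace `model_type` to the attribute
# names that expose comparable hyper-parameters. Assumes the `transformers`
# package is installed; 'roberta-base' is only an illustrative checkpoint.
if __name__ == "__main__":
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("roberta-base")
    names = arch_dict[cfg.model_type]["config_names"]
    print("width:", getattr(cfg, names["width"]))    # 768
    print("layers:", getattr(cfg, names["layers"]))  # 12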


================================================
FILE: open_clip/hf_model.py
================================================
""" huggingface model adapter

Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""

import re

import torch
import torch.nn as nn
from torch import TensorType

try:
    import transformers
    from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig
    from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
        BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
    transformers = None


    class BaseModelOutput:
        pass


    class PretrainedConfig:
        pass

from .hf_configs import arch_dict


# utils
def _camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


# TODO: add a 'last token' pooler for GPT-like models
_POOLERS = {}


def register_pooler(cls):
    """Decorator registering pooler class"""
    _POOLERS[_camel2snake(cls.__name__)] = cls
    return cls


@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)


@register_pooler
class MaxPooler(nn.Module):
    """Max pooling"""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1) == 0, -torch.inf)
        return masked_output.max(1).values


@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling"""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        self.cls_token_position = 0
        self.use_pooler_output = use_pooler_output

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        if (self.use_pooler_output and
            isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
            (x.pooler_output is not None)
        ):
            return x.pooler_output

        return x.last_hidden_state[:, self.cls_token_position, :]
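
# With the poolers above registered, HFTextEncoder below can look one up by
# its snake_case class name, e.g. _POOLERS["mean_pooler"]() or
# _POOLERS["cls_pooler"]().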


class HFTextEncoder(nn.Module):
    """HuggingFace model adapter"""
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
                AutoModel.from_config, self.config)
            # TODO: do all model configs have this attribute? PretrainedConfig defines it, so presumably yes.
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            self.transformer = AutoModel.from_config(config)
        if pooler_type is None:  # get default arch pooler
            pooler_type = (arch_dict[self.config.model_type]["pooler"])
        
        self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )

    def forward(self, x: TensorType):
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        projected = self.proj(pooled_out)

        seq_len = out.last_hidden_state.shape[1]
        tokens = (
            out.last_hidden_state[:, torch.arange(seq_len) != self.pooler.cls_token_position, :] 
            if type(self.pooler) == ClsPooler 
            else out.last_hidden_state
        )
        
        if self.output_tokens:
            return projected, tokens
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        modules = [embeddings, *layer_list][:-unlocked_layers]
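        # e.g. unlocked_layers=2 freezes the embeddings and every block except the last two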
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.gradient_checkpointing_enable()

    def init_parameters(self):
        pass
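

# Usage sketch for HFTextEncoder; 'roberta-base' is only an illustrative
# backbone, and building even the un-pretrained model needs the `transformers`
# package plus a cached or downloadable config.
if __name__ == "__main__":
    encoder = HFTextEncoder("roberta-base", output_dim=512, proj="linear",
                            pooler_type="mean_pooler", pretrained=False)
    ids = torch.randint(5, 1000, (2, 16))  # stand-in token ids; avoids the pad id
    print(encoder(ids).shape)  # torch.Size([2, 512])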


================================================
FILE: open_clip/loss.py
================================================
import torch
import torch.nn as nn
from torch.nn import functional as F

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=False,
        rank=0,
        world_size=1,
        use_horovod=False
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
    if use_horovod:
        assert hvd is not None, 'Please install horovod'
        if gather_with_grad:
            all_image_features = hvd.allgather(image_features)
            all_text_features = hvd.allgather(text_features)
        else:
            with torch.no_grad():
                all_image_features = hvd.allgather(image_features)
                all_text_features = hvd.allgather(text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
                gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
                all_image_features = torch.cat(gathered_image_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
            all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
        else:
            gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
            gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
            dist.all_gather(gathered_image_features, image_features)
            dist.all_gather(gathered_text_features, text_features)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_image_features[rank] = image_features
                gathered_text_features[rank] = text_features
            all_image_features = torch.cat(gathered_image_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features
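
# Note: torch.distributed.nn.all_gather (gather_with_grad=True) keeps
# gradients for every rank's features; in the no-grad path the local rank's
# block is swapped back in (unless local_loss) so the global loss still
# receives gradients from the local features.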


class ClipLoss(nn.Module):

    def __init__(
            self,
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod

        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def get_ground_truth(self, device, num_logits) -> torch.Tensor:
        # calculate ground truth and cache if enabled
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def get_logits(self, image_features, text_features, logit_scale):
        if self.world_size > 1:
            all_image_features, all_text_features = gather_features(
                image_features, text_features,
                self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)

            if self.local_loss:
                logits_per_image = logit_scale * image_features @ all_text_features.T
                logits_per_text = logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logit_scale * text_features @ image_features.T
        
        return logits_per_image, logits_per_text

    def forward(self, image_features, text_features, logit_scale, output_dict=False):
        device = image_features.device
        logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)

        labels = self.get_ground_truth(device, logits_per_image.shape[0])

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        return {"contrastive_loss": total_loss} if output_dict else total_loss


class CoCaLoss(ClipLoss):
    def __init__(
            self,
            caption_loss_weight,
            clip_loss_weight,
            pad_id=0,  # pad_token for open_clip custom tokenizer
            local_loss=False,
            gather_with_grad=False,
            cache_labels=False,
            rank=0,
            world_size=1,
            use_horovod=False,
    ):
        super().__init__(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod
        )

        self.clip_loss_weight = clip_loss_weight
        self.caption_loss_weight = caption_loss_weight
        self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False):
        clip_loss = super().forward(image_features, text_features, logit_scale)
        clip_loss = self.clip_loss_weight * clip_loss

        caption_loss = self.caption_loss(
            logits.permute(0, 2, 1),
            labels,
        )
        caption_loss = caption_loss * self.caption_loss_weight

        if output_dict:
            return {"contrastive_loss": clip_loss, "caption_loss": caption_loss}

        return clip_loss, caption_loss


class DistillClipLoss(ClipLoss):

    def dist_loss(self, teacher_logits, student_logits):
        return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0)

    def forward(
            self,
            image_features,
            text_features,
            logit_scale,
            dist_image_features,
            dist_text_features,
            dist_logit_scale,
            output_dict=False,
    ):
        logits_per_image, logits_per_text = \
            self.get_logits(image_features, text_features, logit_scale)

        dist_logits_per_image, dist_logits_per_text = \
            self.get_logits(dist_image_features, dist_text_features, dist_logit_scale)

        labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0])

        contrastive_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        distill_loss = (
            self.dist_loss(dist_logits_per_image, logits_per_image) +
            self.dist_loss(dist_logits_per_text, logits_per_text)
        ) / 2

        if output_dict:
            return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss}

        return contrastive_loss, distill_loss
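

# Single-process sanity check for ClipLoss: with world_size=1 no gathering
# happens and the loss is plain symmetric cross-entropy over a batch of
# random, L2-normalized stand-in features.
if __name__ == "__main__":
    torch.manual_seed(0)
    img = F.normalize(torch.randn(8, 512), dim=-1)
    txt = F.normalize(torch.randn(8, 512), dim=-1)
    print(float(ClipLoss()(img, txt, logit_scale=torch.tensor(100.0))))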


================================================
FILE: open_clip/model.py
================================================
""" CLIP Model

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
from dataclasses import dataclass
import logging
import math
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from .hf_model import HFTextEncoder
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
from .utils import to_2tuple


@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design
    global_average_pool: bool = False  # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
    attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256 # n_queries for attentional pooler
    attn_pooler_heads: int = 8 # n heads for attentional_pooling
    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    output_tokens: bool = False


@dataclass
class CLIPTextCfg:
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False
    pad_id: int = 0
    output_tokens: bool = False


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == 'bf16':
        cast_dtype = torch.bfloat16
    elif precision == 'fp16':
        cast_dtype = torch.float16
    return cast_dtype


def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
    # memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if vision_cfg.timm_model_name:
        visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            proj_bias=vision_cfg.timm_proj_bias,
            drop=vision_cfg.timm_drop,
            drop_path=vision_cfg.timm_drop_path,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size,
        )
        act_layer = nn.GELU  # so that text transformer doesn't use QuickGELU w/ timm models
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        visual = VisionTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            ls_init_value=vision_cfg.ls_init_value,
            patch_dropout=vision_cfg.patch_dropout,
            input_patchnorm=vision_cfg.input_patchnorm,
            global_average_pool=vision_cfg.global_average_pool,
            attentional_pool=vision_cfg.attentional_pool,
            n_queries=vision_cfg.n_queries,
            attn_pooler_heads=vision_cfg.attn_pooler_heads,
            output_tokens=vision_cfg.output_tokens,
            output_dim=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

    return visual


def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    if text_cfg.hf_model_name:
        text = HFTextEncoder(
            text_cfg.hf_model_name,
            output_dim=embed_dim,
            proj=text_cfg.proj,
            pooler_type=text_cfg.pooler_type,
            pretrained=text_cfg.hf_model_pretrained,
            output_tokens=text_cfg.output_tokens,
        )
    else:
        act_layer = QuickGELU if quick_gelu else nn.GELU
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

        text = TextTransformer(
            context_length=text_cfg.context_length,
            vocab_size=text_cfg.vocab_size,
            width=text_cfg.width,
            heads=text_cfg.heads,
            layers=text_cfg.layers,
            ls_init_value=text_cfg.ls_init_value,
            output_dim=embed_dim,
            embed_cls=text_cfg.embed_cls,
            output_tokens=text_cfg.output_tokens,
            pad_id=text_cfg.pad_id,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
    return text


class CLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, out_layers=(), normalize: bool = False):
        # empty out_layers default requests no intermediate features, so forward() can omit it
        features = self.visual(image, out_layers)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


class CustomTextCLIP(nn.Module):
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()


def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)"""

    def _convert_weights(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

        if isinstance(l, (nn.MultiheadAttention, Attention)):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.to(dtype)

    model.apply(_convert_weights)


convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat


# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    if 'text_projection' in state_dict:
        # old format state_dict, move text tower -> .text
        new_state_dict = {}
        for k, v in state_dict.items():
            if any(k.startswith(p) for p in (
                'text_projection',
                'positional_embedding',
                'token_embedding',
                'transformer',
                'ln_final',
            )):
                k = 'text.' + k
            new_state_dict[k] = v
        return new_state_dict
    return state_dict
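
# e.g. 'transformer.resblocks.0.attn.in_proj_weight' is renamed to
# 'text.transformer.resblocks.0.attn.in_proj_weight', letting old single-tower
# checkpoints load into CustomTextCLIP's `.text` submodule.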


def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
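
# Worked example of the ViT branch above: an OpenAI ViT-B/16 state dict has a
# (197, 768) 'visual.positional_embedding', so grid_size = round(196 ** 0.5) = 14
# and image_size = 16 * 14 = 224.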


def trace_model(model, batch_size=256, device=torch.device('cpu')):
    model.eval()
    image_size = model.visual.image_size
    example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    model = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(example_images, example_text),
            encode_text=(example_text,),
            encode_image=(example_images,)
        ))
    model.visual.image_size = image_size
    return model


def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    # Rescale the grid of position embeddings when loading from state_dict
    # remember which key held the embedding so it is written back to the same key
    has_visual_pos_embed = True
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None:
        has_visual_pos_embed = False
        old_pos_embed = state_dict.get('visual.attnpool.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        return

    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    if has_visual_pos_embed:
        state_dict['visual.positional_embedding'] = new_pos_embed
    else:
        state_dict['visual.attnpool.positional_embedding'] = new_pos_embed
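

# Shape sketch of resize_pos_embed for a ViT-B/16 checkpoint (14x14 patch grid
# plus one class token) being resized to a 16x16 grid; the tensors are random
# stand-ins for real weights.
if __name__ == "__main__":
    pos = torch.randn(197, 768)                      # 1 class token + 14*14 patches
    tok, img = pos[:1], pos[1:]
    img = img.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)
    img = F.interpolate(img, size=(16, 16), mode='bicubic', antialias=True, align_corners=False)
    img = img.permute(0, 2, 3, 1).reshape(1, 16 * 16, 768)[0]
    print(torch.cat([tok, img], dim=0).shape)        # torch.Size([257, 768])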


================================================
FILE: open_clip/model_configs/RN101-quickgelu.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            23,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/RN101.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            23,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/RN50-quickgelu.json
================================================
{
    "embed_dim": 1024,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}


================================================
FILE: open_clip/model_configs/RN50.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": [
            3,
            4,
            6,
            3
        ],
        "width": 64,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/RN50x16.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 384,
        "layers": [
            6,
            8,
            18,
            8
        ],
        "width": 96,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/RN50x4.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 288,
        "layers": [
            4,
            6,
            10,
            6
        ],
        "width": 80,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/RN50x64.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 448,
        "layers": [
            3,
            15,
            36,
            10
        ],
        "width": 128,
        "patch_size": null
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-16-plus-240.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 240,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-16-plus.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 896,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-16.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-32-plus-256.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "image_size": 256,
        "layers": 12,
        "width": 896,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-32-quickgelu.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: open_clip/model_configs/ViT-H-16.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: open_clip/model_configs/ViT-L-14-280.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 280,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-L-14-336.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 336,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-L-14.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-L-16-320.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 320,
        "layers": 24,
        "width": 1024,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-L-16.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-M-16-alt.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16,
        "ls_init_value": 1e-4
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-M-16.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-M-32-alt.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-M-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 512,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-S-16-alt.json
================================================
{
    "embed_dim": 256,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 256,
        "heads": 4,
        "layers": 10
    }
}

================================================
FILE: open_clip/model_configs/ViT-S-16.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 16
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-S-32-alt.json
================================================
{
    "embed_dim": 256,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 256,
        "heads": 4,
        "layers": 10
    }
}

================================================
FILE: open_clip/model_configs/ViT-S-32.json
================================================
{
    "embed_dim": 384,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 384,
        "patch_size": 32
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 384,
        "heads": 6,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/ViT-bigG-14.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 224,
        "layers": 48,
        "width": 1664,
        "head_width": 104,
        "mlp_ratio": 4.9231,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 32
    }
}

================================================
FILE: open_clip/model_configs/ViT-e-14.json
================================================
{
    "embed_dim": 1280,
    "vision_cfg": {
        "image_size": 224,
        "layers": 56,
        "width": 1792,
        "head_width": 112,
        "mlp_ratio": 8.5715,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1280,
        "heads": 20,
        "layers": 36
    }
}

================================================
FILE: open_clip/model_configs/ViT-g-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 40,
        "width": 1408,
        "head_width": 88,
        "mlp_ratio": 4.3637,
        "patch_size": 14
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: open_clip/model_configs/coca_ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32,
        "attentional_pool": true,
        "attn_pooler_heads": 8,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "embed_cls": true,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12,
        "attn_pooler_heads": 8
    },
    "custom_text": true
}

================================================
FILE: open_clip/model_configs/coca_ViT-L-14.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "image_size": 224,
        "layers": 24,
        "width": 1024,
        "patch_size": 14,
        "attentional_pool": true,
        "attn_pooler_heads": 8,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12,
        "embed_cls": true,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12,
        "attn_pooler_heads": 12
    },
    "custom_text": true
}


================================================
FILE: open_clip/model_configs/coca_base.json
================================================
{
    "embed_dim": 512,
    "multimodal_cfg": {
        "width": 768,
        "context_length": 76,
        "vocab_size": 64000,
        "mlp_ratio": 4,
        "layers": 12,
        "dim_head": 64,
        "heads": 12,
        "n_queries": 256,
        "attn_pooler_heads": 8
    },
    "vision_cfg": {
        "image_size": 288,
        "layers": 12,
        "width": 768,
        "patch_size": 18,
        "output_tokens": true
    },
    "text_cfg": {
        "context_length": 76,
        "vocab_size": 64000,
        "layers": 12,
        "heads": 12,
        "width": 768,
        "embed_cls": true,
        "output_tokens": true
    },
    "custom_text": true
}

================================================
FILE: open_clip/model_configs/coca_roberta-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32,
        "output_tokens": true
    },
    "text_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "proj": "linear",
        "width": 768,
        "output_tokens": true
    },
    "multimodal_cfg": {
        "context_length": 76,
        "width": 768,
        "heads": 8,
        "layers": 12
    },
    "custom_text": true
}


================================================
FILE: open_clip/model_configs/convnext_base.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_base_w.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_base_w_320.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "convnext_base",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_large.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_large_d.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}

================================================
FILE: open_clip/model_configs/convnext_large_d_320.json
================================================
{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}

================================================
FILE: open_clip/model_configs/convnext_small.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "convnext_small",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_tiny.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_tiny",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/convnext_xlarge.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 20
    }
}

================================================
FILE: open_clip/model_configs/convnext_xxlarge.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: open_clip/model_configs/convnext_xxlarge_320.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "timm_drop": 0.0,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}

================================================
FILE: open_clip/model_configs/mt5-base-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-base",
        "hf_tokenizer_name": "google/mt5-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}


================================================
FILE: open_clip/model_configs/mt5-xl-ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-xl",
        "hf_tokenizer_name": "google/mt5-xl",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}


================================================
FILE: open_clip/model_configs/roberta-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}


================================================
FILE: open_clip/model_configs/swin_base_patch4_window7_224.json
================================================
{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "swin_base_patch4_window7_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/vit_medium_patch16_gap_256.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_medium_patch16_gap_256",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_relpos_medium_patch16_cls_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}

================================================
FILE: open_clip/model_configs/xlm-roberta-base-ViT-B-32.json
================================================
{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-base",
        "hf_tokenizer_name": "xlm-roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}


================================================
FILE: open_clip/model_configs/xlm-roberta-large-ViT-H-14.json
================================================
{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-large",
        "hf_tokenizer_name": "xlm-roberta-large",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
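
These model_configs JSON files are pure data: the open_clip factory resolves a
model name such as "xlm-roberta-large-ViT-H-14" to the same-named file in this
directory and builds the vision and text towers from its vision_cfg/text_cfg.
A minimal sketch of inspecting one directly (the relative path is an
assumption; only the standard library is needed):

import json

with open("open_clip/model_configs/xlm-roberta-large-ViT-H-14.json") as f:
    cfg = json.load(f)
print(cfg["embed_dim"])                  # 1024
print(cfg["text_cfg"]["hf_model_name"])  # xlm-roberta-large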


================================================
FILE: open_clip/modified_resnet.py
================================================
from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

from open_clip.utils import freeze_batch_norm_2d


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.act3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x, out_blocks):
        x = self.stem(x)
        x_1 = self.layer1(x)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)
        x = self.attnpool(x_4)

        out_tokens = []
        x_blocks = [x_1, x_2, x_3, x_4]
        for i in out_blocks:
            out_tokens.append(x_blocks[i - 1])  # out_blocks is 1-indexed (1 -> layer1, ..., 4 -> layer4)

        return x, out_tokens
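

# A minimal usage sketch, assuming the RN50 geometry from this repo's RN50
# config (layers (3, 4, 6, 3), width 64, output_dim 1024, 32 attention-pool
# heads); `out_blocks` selects which residual stages come back as feature maps.
if __name__ == "__main__":
    net = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32)
    pooled, feats = net(torch.randn(1, 3, 224, 224), out_blocks=[1, 2, 3, 4])
    print(pooled.shape)                 # torch.Size([1, 1024])
    print([f.shape[1] for f in feats])  # [256, 512, 1024, 2048]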


================================================
FILE: open_clip/openai.py
================================================
""" OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import os
import warnings
from typing import List, Optional, Union

import torch

from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url

__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list_pretrained_models_by_tag('openai')


def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = True,
        cache_dir: Optional[str] = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `list_openai_models()`, or the path to a model checkpoint containing the state_dict
    precision: str
        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        precision = 'fp32' if device == 'cpu' else 'fp16'

    if get_pretrained_url(name, 'openai'):
        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Build a non-jit model from the OpenAI jitted model state dict
        cast_dtype = get_cast_dtype(precision)
        try:
            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
        except KeyError:
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}  # strip the 'module.' prefix
            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
        model = model.to(device)
        if precision.startswith('amp') or precision == 'fp32':
            model.float()
        elif precision == 'bf16':
            convert_weights_to_lp(model, dtype=torch.bfloat16)

        return model

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 (typically for CPU)
    if precision == 'fp32':
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    # ensure image_size attr available at consistent location for both jit and non-jit
    model.visual.image_size = model.input_resolution.item()
    return model
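

# A minimal usage sketch: load OpenAI ViT-B-16 as a non-JIT fp32 model on CPU
# (the checkpoint is downloaded on first use, then served from the cache).
if __name__ == "__main__":
    clip_model = load_openai_model("ViT-B-16", precision="fp32", device="cpu", jit=False)
    print(sum(p.numel() for p in clip_model.parameters()))  # ~150M parameters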


================================================
FILE: open_clip/pretrained.py
================================================
import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union

from tqdm import tqdm

from .version import __version__

try:
    from huggingface_hub import hf_hub_download
    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False


def _pcfg(url='', hf_hub='', mean=None, std=None):
    return dict(
        url=url,
        hf_hub=hf_hub,
        mean=mean,
        std=std,
    )


_RN50 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN50_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN101 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN101_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN50x4 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)

_RN50x16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)

_RN50x64 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)

_VITB32 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
    laion2b_e16=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)

_VITB32_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)

_VITB16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
    # laion400m_32k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # laion400m_64k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)

_VITB16_PLUS_240 = dict(
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)

_VITL14 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
    laion2b_s32b_b82k=_pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)

_VITL14_336 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)

_VITH14 = dict(
    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)

_VITg14 = dict(
    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
)

_VITbigG14 = dict(
    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)

_robertaViTB32 = dict(
    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)

_xlmRobertaBaseViTB32 = dict(
    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)

_xlmRobertaLargeFrozenViTH14 = dict(
    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)

_convnext_base = dict(
    laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),
)

_convnext_base_w = dict(
    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)

_convnext_base_w_320 = dict(
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)

_convnext_large_d = dict(
    laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),
)

_convnext_large_d_320 = dict(
    laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),
    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)

_convnext_xxlarge = dict(
    laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),
    laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),
    laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),
)

_coca_VITB32 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')
)

_coca_VITL14 = dict(
    laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),
    mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')
)


_PRETRAINED = {
    "RN50": _RN50,
    "RN50-quickgelu": _RN50_quickgelu,
    "RN101": _RN101,
    "RN101-quickgelu": _RN101_quickgelu,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,
    "ViT-B-32": _VITB32,
    "ViT-B-32-quickgelu": _VITB32_quickgelu,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d": _convnext_large_d,
    "convnext_large_d_320": _convnext_large_d_320,
    "convnext_xxlarge": _convnext_xxlarge,
    "coca_ViT-B-32": _coca_VITB32,
    "coca_ViT-L-14": _coca_VITL14,
}


def _clean_tag(tag: str):
    # normalize pretrained tags
    return tag.lower().replace('-', '_')


def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]


def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    models = []
    tag = _clean_tag(tag)
    for k in _PRETRAINED.keys():
        if tag in _PRETRAINED[k]:
            models.append(k)
    return models


def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    tags = []
    if model in _PRETRAINED:
        tags.extend(_PRETRAINED[model].keys())
    return tags


def is_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return False
    return _clean_tag(tag) in _PRETRAINED[model]


def get_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return {}
    model_pretrained = _PRETRAINED[model]
    return model_pretrained.get(_clean_tag(tag), {})


def get_pretrained_url(model: str, tag: str):
    cfg = get_pretrained_cfg(model, _clean_tag(tag))
    return cfg.get('url', '')


def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")

    return download_target
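

# Worked example of the checksum extraction above: for an openaipublic URL the
# path segment before the filename ('afeb0e10...' for RN50.pt) is the expected
# SHA-256, while for an mlfoundations release asset such as
# 'rn50-quickgelu-cc12m-f000538c.pt' the suffix after the last '-' ('f000538c')
# is matched as a checksum prefix.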


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def download_pretrained_from_hf(
        model_id: str,
        filename: str = 'open_clip_pytorch_model.bin',
        revision=None,
        cache_dir: Union[str, None] = None,
):
    has_hf_hub(True)
    cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
    return cached_file


def download_pretrained(
        cfg: Dict,
        force_hf_hub: bool = False,
        cache_dir: Union[str, None] = None,
):
    target = ''
    if not cfg:
        return target

    download_url = cfg.get('url', '')
    download_hf_hub = cfg.get('hf_hub', '')
    if download_hf_hub and force_hf_hub:
        # use HF hub even if url exists
        download_url = ''

    if download_url:
        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
    elif download_hf_hub:
        has_hf_hub(True)
        # we assume the hf_hub entries in pretrained config combine model_id + filename in
        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
        model_id, filename = os.path.split(download_hf_hub)
        if filename:
            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
        else:
            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)

    return target
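

# A minimal sketch of the tag -> download flow (hf_hub entries need the
# huggingface_hub package; the values below come straight from _PRETRAINED):
if __name__ == "__main__":
    cfg = get_pretrained_cfg("ViT-B-32", "laion2b_s34b_b79k")
    print(cfg["hf_hub"])  # laion/CLIP-ViT-B-32-laion2B-s34B-b79K/
    # download_pretrained(cfg) would fetch 'open_clip_pytorch_model.bin' from
    # that repo, because the entry ends in a trailing slash (no filename given).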


================================================
FILE: open_clip/push_to_hf_hub.py
================================================
import argparse
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple

import torch

try:
    from huggingface_hub import (
        create_repo,
        get_hf_file_metadata,
        hf_hub_download,
        hf_hub_url,
        repo_type_and_id_from_hf_id,
        upload_folder,
    )
    from huggingface_hub.utils import EntryNotFoundError
    _has_hf_hub = True
except ImportError:
    _has_hf_hub = False

from .factory import create_model_from_pretrained, get_model_config, get_tokenizer
from .tokenizer import HFTokenizer


def save_config_for_hf(
        model,
        config_path: str,
        model_config: Optional[dict]
):
    preprocess_cfg = {
        'mean': model.visual.image_mean,
        'std': model.visual.image_std,
    }
    hf_config = {
        'model_cfg': model_config,
        'preprocess_cfg': preprocess_cfg,
    }

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def save_for_hf(
    model,
    tokenizer: HFTokenizer,
    model_config: dict,
    save_directory: str,
    weights_filename='open_clip_pytorch_model.bin',
    config_filename='open_clip_config.json',
):
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    weights_path = save_directory / weights_filename
    torch.save(model.state_dict(), weights_path)

    tokenizer.save_pretrained(save_directory)

    config_path = save_directory / config_filename
    save_config_for_hf(model, config_path, model_config=model_config)


def push_to_hf_hub(
    model,
    tokenizer,
    model_config: Optional[dict],
    repo_id: str,
    commit_message: str = 'Add model',
    token: Optional[str] = None,
    revision: Optional[str] = None,
    private: bool = False,
    create_pr: bool = False,
    model_card: Optional[dict] = None,
):
    if not isinstance(tokenizer, HFTokenizer):
        # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
        tokenizer = HFTokenizer('openai/clip-vit-large-patch14')

    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if README file already exist in repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tokenizer=tokenizer,
            model_config=model_config,
            save_directory=tmpdir,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )


def push_pretrained_to_hf_hub(
    model_name,
    pretrained: str,
    repo_id: str,
    image_mean: Optional[Tuple[float, ...]] = None,
    image_std: Optional[Tuple[float, ...]] = None,
    commit_message: str = 'Add model',
    token: Optional[str] = None,
    revision: Optional[str] = None,
    private: bool = False,
    create_pr: bool = False,
    model_card: Optional[dict] = None,
):
    model, preprocess_eval = create_model_from_pretrained(
        model_name,
        pretrained=pretrained,
        image_mean=image_mean,
        image_std=image_std,
    )

    model_config = get_model_config(model_name)
    assert model_config

    tokenizer = get_tokenizer(model_name)

    push_to_hf_hub(
        model=model,
        tokenizer=tokenizer,
        model_config=model_config,
        repo_id=repo_id,
        commit_message=commit_message,
        token=token,
        revision=revision,
        private=private,
        create_pr=create_pr,
        model_card=model_card,
    )
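

# A minimal invocation sketch, mirroring the CLI entry point at the bottom of
# this file (the repo id is a placeholder; a valid HF token must be configured):
#
#   push_pretrained_to_hf_hub(
#       "ViT-B-32",
#       pretrained="laion2b_s34b_b79k",
#       repo_id="your-org/your-clip-model",
#   )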


def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- zero-shot-image-classification\n- clip\n"
    readme_text += "library_tag: open_clip\n"
    readme_text += f"license: {model_card.get('license', 'mit')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"
    if 'description' in model_card:
        readme_text += f"\n{model_card['description']}\n"
    if 'details' in model_card:
        readme_text += f"\n## Model Details\n"
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                readme_text += f"- **{k}:**\n"
                for vi in v:
                    readme_text += f"  - {vi}\n"
            elif isinstance(v, dict):
                readme_text += f"- **{k}:**\n"
                for ki, vi in v.items():
                    readme_text += f"  - {ki}: {vi}\n"
            else:
                readme_text += f"- **{k}:** {v}\n"
    if 'usage' in model_card:
        readme_text += f"\n## Model Usage\n"
        readme_text += model_card['usage']
        readme_text += '\n'

    if 'comparison' in model_card:
        readme_text += f"\n## Model Comparison\n"
        readme_text += model_card['comparison']
        readme_text += '\n'

    if 'citation' in model_card:
        readme_text += f"\n## Citation\n"
        if not isinstance(model_card['citation'], (list, tuple)):
            citations = [model_card['citation']]
        else:
            citations = model_card['citation']
        for c in citations:
            readme_text += f"```bibtex\n{c}\n```\n"

    return readme_text


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
    parser.add_argument(
        "--model", type=str, help="Name of the model to use.",
    )
    parser.add_argument(
        "--pretrained", type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--repo-id", type=str,
        help="Destination HF Hub repo-id ie 'organization/model_id'.",
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of dataset')
    args = parser.parse_args()

    print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')

    # FIXME add support to pass model_card json / template from file via cmd line

    push_pretrained_to_hf_hub(
        args.model,
        args.pretrained,
        args.repo_id,
        image_mean=args.image_mean,  # override image mean/std if trained w/ non defaults
        image_std=args.image_std,
    )

    print(f'{args.model} saved.')


================================================
FILE: open_clip/timm_model.py
================================================
""" timm model adapter

Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # old timm imports < 0.8.1
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # new timm imports >= 0.8.1
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """ timm model adapter
    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
        feat_size = self.trunk.default_cfg.get('pool_size', None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ('abs_attn', 'rot_attn'):
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool='')
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, 'projection layer needed if non-attention pooling is used.'

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception as e:
            logging.warning(f'grad checkpointing not supported for this timm image tower, continuing without... ({e})')

    def forward(self, x):
        x = self.trunk(x)
        x = self.head(x)
        return x
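

# A minimal construction sketch, assuming timm is installed; the model name and
# embed dim mirror the convnext_base config (default 'avg' pool, linear proj).
if __name__ == "__main__":
    tower = TimmModel("convnext_base", embed_dim=512, image_size=224)
    print(tower(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 512])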


================================================
FILE: open_clip/tokenizer.py
================================================
""" CLIP tokenizer

Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List

import ftfy
import regex as re
import torch

# silence the HF tokenizers fork/parallelism warning: https://stackoverflow.com/q/62691279
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to printable unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
    and we avoid mapping to whitespace/control characters that the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
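

# Example: space (byte 0x20) falls outside the printable ranges above, so it
# maps to a shifted code point instead of itself:
#   bytes_to_unicode()[ord(" ")] == "Ġ"  (U+0120 = 256 + 32)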


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
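

# Example: get_pairs(("h", "e", "l", "l", "o</w>"))
#   -> {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o</w>")}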


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]  # drop the header line; keep the 48,894 bpe merge rules
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t:t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` no longer occurs in the remainder of the word
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


_tokenizer = SimpleTokenizer()

def decode(output_ids: torch.Tensor):
    output_ids = output_ids.cpu().numpy()
    return _tokenizer.decode(output_ids)

def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


class HFTokenizer:
    """HuggingFace tokenizer wrapper"""

    def __init__(self, tokenizer_name: str):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def save_pretrained(self, dest):
        self.tokenizer.save_pretrained(dest)

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
        # Apply the same cleaning as the default tokenizer, but without
        # lowercasing: adding .lower() would make case-sensitive tokenizers
        # more robust, at the cost of sensitivity to case nuance.
        if isinstance(texts, str):
            texts = [texts]
        texts = [whitespace_clean(basic_clean(text)) for text in texts]
        input_ids = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        ).input_ids
        return input_ids


================================================
FILE: open_clip/transform.py
================================================
import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms.functional as F

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD


@dataclass
class AugmentationCfg:
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
    interpolation: Optional[str] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    use_timm: bool = False


class ResizeMaxSize(nn.Module):

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == 'min' else max  # note: forward() below always scales by the longest side
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]  # tensor images are (..., H, W)
        else:
            width, height = img.size
        scale = self.max_size / float(max(height, width))
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
            pad_h = self.max_size - new_size[0]
            pad_w = self.max_size - new_size[1]
            img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
        return img


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            from timm.data import create_transform  # timm can still be optional
            if isinstance(image_size, (tuple, list)):
                assert len(image_size) >= 2
                input_size = (3,) + image_size[-2:]
            else:
                input_size = (3, image_size, image_size)
            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
            aug_cfg_dict.setdefault('interpolation', 'random')
            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                **aug_cfg_dict,
            )
        else:
            train_transform = Compose([
                RandomResizedCrop(
                    image_size,
                    scale=aug_cfg_dict.pop('scale'),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
                ToTensor(),
                normalize,
            ])
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform
    else:
        if resize_longest_max:
            transforms = [
                ResizeMaxSize(image_size, fill=fill_color)
            ]
        else:
            transforms = [
                Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
                CenterCrop((image_size, image_size)),
            ]
        transforms.extend([
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
        return Compose(transforms)
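
# Hedged usage sketch (illustrative, not from the original file): building the
# eval-time preprocessing and applying it to a stand-in PIL image.
#
#   from PIL import Image
#   preprocess = image_transform(image_size=224, is_train=False)
#   x = preprocess(Image.new('RGB', (640, 480)))   # -> tensor [3, 224, 224]
#
# With resize_longest_max=True, the longest side is scaled to `image_size` and
# the result is padded with `fill_color`, preserving aspect ratio.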


================================================
FILE: open_clip/transformer.py
================================================
from collections import OrderedDict
import math
from typing import Callable, Optional, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint

from .utils import to_2tuple
import numpy as np


class LayerNormFp32(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(orig_type)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.to(orig_type)


class QuickGELU(nn.Module):
    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        batch_indices = torch.arange(batch)
        batch_indices = batch_indices[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
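
# Hedged usage sketch (illustrative, not from the original file): in training
# mode, PatchDropout keeps max(1, int(N * (1 - prob))) of the N patch tokens
# and re-attaches the CLS token when exclude_first_token=True.
#
#   pd = PatchDropout(prob=0.5)
#   pd.train()
#   y = pd(torch.randn(2, 197, 768))   # 196 patches -> 98 kept, so [2, 99, 768]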


class Attention(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=False,
            scale_heads=False,
            logit_scale_max=math.log(1. / 0.01),
            attn_drop=0.,
            proj_drop=0.
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        L, N, C = x.shape
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

        if self.logit_scale is not None:
            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn.view(N, self.num_heads, L, L) * logit_scale
            attn = attn.view(-1, L, L)
        else:
            q = q * self.scale
            attn = torch.bmm(q, k.transpose(-1, -2))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            attn += attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.bmm(attn, v)
        if self.head_scale is not None:
            x = x.view(N, self.num_heads, L, C) * self.head_scale
            x = x.view(-1, L, C)
        x = x.transpose(0, 1).reshape(L, N, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
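
# Hedged shape check (illustrative, not from the original file): the module
# expects sequence-first (L, N, C) input, matching nn.MultiheadAttention's
# default layout, and preserves the input shape.
#
#   attn = Attention(dim=768, num_heads=12)
#   y = attn(torch.randn(50, 2, 768))   # -> [50, 2, 768]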


class AttentionalPooler(nn.Module):
    def __init__(
            self,
            d_model: int,
            context_dim: int,
            n_head: int = 8,
            n_queries: int = 256,
            norm_layer: Callable = LayerNorm
    ):
        super().__init__()
        self.query = nn.Parameter(torch.randn(n_queries, d_model))
        self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim)
        self.ln_q = norm_layer(d_model)
        self.ln_k = norm_layer(context_dim)

    def forward(self, x: torch.Tensor):
        x = self.ln_k(x).permute(1, 0, 2)  # NLD -> LND
        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0]
        return out.permute(1, 0, 2)  # LND -> NLD

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)


class ResidualAttentionBlock(nn.Module):
    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            is_cross_attention: bool = False,
            idx: int = 12,
    ):
        super().__init__()

        self.idx = idx

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
        return self.attn(
            q_x, k_x, v_x, need_weights=True, attn_mask=attn_mask
        )

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None

        tmp, attn = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        x = q_x + self.ls_1(tmp)
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x, attn


class CustomResidualAttentionBlock(nn.Module):
    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model, n_head,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
        )
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer,
                idx=idx)
            for idx in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],
                attn_mask: Optional[torch.Tensor] = None):
        idx = 0
        out_attn = []
        # out_tokens = x
        out_tokens = []
        for r in self.resblocks:
            idx += 1
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, None, None, attn_mask)
            else:
                if idx == 12:  # keep attention weights from block 12 only (hard-coded)
                    x, attn = r(x, attn_mask=attn_mask)
                    out_attn.append(attn)
                else:
                    x, attn_tmp = r(x, attn_mask=attn_mask)
                if idx in out_layers:
                    out_tokens.append(x)
                    # out_tokens = x
        return x, out_attn, out_tokens
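
# Hedged usage sketch (illustrative, not from the original file): forward()
# returns the final hidden states together with the token sequences collected
# at `out_layers` (1-indexed block numbers); attention maps are kept only
# from block 12.
#
#   t = Transformer(width=64, layers=12, heads=4)
#   x, out_attn, out_tokens = t(torch.randn(10, 2, 64), out_layers=[3, 6, 9])
#   # x: [10, 2, 64]; len(out_tokens) == 3; len(out_attn) == 1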


class VisionTransformer(nn.Module):
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: float = None,
            global_average_pool: bool = False,
            attentional_pool: bool = False,
            n_queries: int = 256,
            attn_pooler_heads: int = 8,
            output_dim: int = 512,
            patch_dropout: float = 0.,
            input_patchnorm: bool = False,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_tokens: bool = False
    ):
        super().__init__()
        self.output_tokens = output_tokens
        image_height, image_width = self.image_size = to_2tuple(image_size)
        patch_height, patch_width = self.patch_size = to_2tuple(patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.output_dim = output_dim

        # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1
        self.input_patchnorm = input_patchnorm

        if input_patchnorm:
            patch_input_dim = patch_height * patch_width * 3
            self.patchnorm_pre_ln = LayerNorm(patch_input_dim)
            self.conv1 = nn.Linear(patch_input_dim, width)
        else:
            self.patchnorm_pre_ln = nn.Identity()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size,
                                   bias=False)

        # class embeddings and positional embeddings
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.global_average_pool = global_average_pool
        if attentional_pool:
            self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries)
            self.ln_post = norm_layer(output_dim)
            self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim))
        else:
            self.attn_pool = None
            self.ln_post = norm_layer(width)
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
        # TODO experiment if default PyTorch init, below, or alternate init is best.

        # nn.init.normal_(self.class_embedding, std=self.scale)
        # nn.init.normal_(self.positional_embedding, std=self.scale)
        #
        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        # attn_std = self.transformer.width ** -0.5
        # fc_std = (2 * self.transformer.width) ** -0.5
        # for block in self.transformer.resblocks:
        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        #
        # if self.text_projection is not None:
        #     nn.init.normal_(self.text_projection, std=self.scale)
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.global_average_pool:
            return x.mean(dim=1), x
        else:
            return x[:, 0], x[:, 1:]

    def forward(self, x: torch.Tensor, out_layers: list):

        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if self.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1],
                          self.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1)
            x = self.patchnorm_pre_ln(x)
            x = self.conv1(x)
        else:
            x = self.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attn, patch_tokens = self.transformer(x, out_layers)
        # attn = attn[0, 0, 1:].view(14, 14)  # 49
        # aggregate the CLS-token attention maps collected by the transformer
        B, C, L = attn[0].shape
        H = int(np.sqrt(L - 1))
        out_attn = torch.zeros([H, H], device=x.device)  # follow the input's device instead of assuming CUDA
        for i in range(len(attn)):
            out_attn += attn[i][0, 0, 1:].view(H, H)
        # NOTE: out_attn is computed for inspection but never returned
        x = x.permute(1, 0, 2)  # LND -> NLD
        patch_tokens = [patch_tokens[t].permute(1, 0, 2) for t in range(len(patch_tokens))]  # LND -> NLD
        # patch_tokens = patch_tokens.permute(1, 0, 2)  # LND -> NLD

        if self.attn_pool is not None:
            x = self.attn_pool(x)
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)
        else:
            pooled, tokens = self._global_pool(x)
            pooled = self.ln_post(pooled)
            # patch_pooled, patch_tokens = self._global_pool(patch_tokens)
            # tokens = self.ln_post(tokens)

        if self.proj is not None:
            pooled = pooled @ self.proj
            # patch_tokens = patch_tokens @ self.proj  # not sure whether this would work
            # tokens = tokens @ self.proj

        # NOTE: both branches here are identical; patch tokens are always
        # returned alongside the pooled embedding in this fork
        return pooled, patch_tokens
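
# Hedged usage sketch (illustrative, not from the original file), using a
# ViT-B/32-like configuration: the pooled embedding and the intermediate
# patch tokens (one tensor per entry of `out_layers`) are both returned.
#
#   vit = VisionTransformer(image_size=224, patch_size=32, width=768,
#                           layers=12, heads=12, mlp_ratio=4.0, output_dim=512)
#   pooled, patch_tokens = vit(torch.randn(1, 3, 224, 224), out_layers=[3, 6, 9])
#   # pooled: [1, 512]; patch_tokens: three tensors of shape [1, 50, 768]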


class TextTransformer(nn.Module):
    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            context_length: int = 77,
            vocab_size: int = 49408,
            width: int = 512,
            heads: int = 8,
            layers: int = 12,
            ls_init_value: float = None,
            output_dim: int = 512,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            embed_cls: bool = False,
            pad_id: int = 0,
            output_tokens: bool = False,
    ):
        super().__init__()
        self.output_tokens = output_tokens
        self.num_pos = self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim
        self.heads = heads
        self.pad_id = pad_id

        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        if embed_cls:
            self.cls_emb = nn.Parameter(torch.empty(width))
            self.num_pos += 1
        else:
            self.cls_emb = None

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.ln_final = norm_layer(width)

        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        if self.cls_emb is not None:
            nn.init.normal_(self.cls_emb, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.grad_checkpointing = enable

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def build_cls_mask(self, text, cast_dtype: torch.dtype):
        cls_mask = (text != self.pad_id).unsqueeze(1)
        cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
        additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
        additive_mask.fill_(0)
        additive_mask.masked_fill_(~cls_mask, float("-inf"))
        additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
        return additive_mask

    def _repeat(self, t, N: int):
        return t.reshape(1, 1, -1).repeat(N, 1, 1)

    def forward(self, text):
        cast_dtype = self.transformer.get_cast_dtype()
        seq_len = text.shape[1]

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        attn_mask = self.attn_mask
        if self.cls_emb is not None:
            seq_len += 1
            x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
            cls_mask = self.build_cls_mask(text, cast_dtype)
            attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]

        x = x + self.positional_embedding[:seq_len].to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x, attn, patch_tokens = self.transformer(x, attn_mask=attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        if self.cls_emb is not None:
            pooled, tokens = x[:, -1], x[:, :-1]
            pooled = self.ln_final(pooled)
        else:
            x = self.ln_final(x)
            pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x

        if self.text_projection is not None:
            pooled = pooled @ self.text_projection

        if self.output_tokens:
            return pooled, tokens

        return pooled
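
# Hedged usage sketch (illustrative, not from the original file): the pooled
# feature is taken at the argmax token id, which is the EOT token for real
# tokenized text.
#
#   tt = TextTransformer(context_length=77, vocab_size=49408, width=512,
#                        heads=8, layers=12, output_dim=512)
#   feats = tt(torch.randint(0, 49408, (2, 77)))   # -> [2, 512]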


class MultimodalTransformer(Transformer):
    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            context_length: int = 77,
            mlp_ratio: float = 4.0,
            ls_init_value: float = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            output_dim: int = 512,
    ):

        super().__init__(
            width=width,
            layers=layers,
            heads=heads,
            mlp_ratio=mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.context_length = context_length
        self.cross_attn = nn.ModuleList([
            ResidualAttentionBlock(
                width,
                heads,
                mlp_ratio,
                ls_init_value=ls_init_value,
                act_layer=act_layer,
                norm_layer=norm_layer,
                is_cross_attention=True,
            )
            for _ in range(layers)
        ])

        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

    def init_parameters(self):
        # MultimodalTransformer inherits Transformer directly, so the block
        # attributes live on self (not on a nested self.transformer module)
        proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
        attn_std = self.width ** -0.5
        fc_std = (2 * self.width) ** -0.5
        # (continuation reconstructed from the matching pattern in
        #  TextTransformer.init_parameters above; the preview was cut off here)
        for block in self.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        for block in self.cross_attn:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.width ** -0.5)

[... remainder of open_clip/transformer.py omitted from this preview; per the symbol index below, MultimodalTransformer also defines build_attention_mask, forward, and set_grad_checkpointing ...]
================================================
SYMBOL INDEX (227 symbols across 23 files)
================================================

FILE: data/mvtec.py
  class MVTecSolver (line 5) | class MVTecSolver(object):
    method __init__ (line 12) | def __init__(self, root='data/mvtec'):
    method run (line 16) | def run(self):

FILE: data/visa.py
  class VisASolver (line 6) | class VisASolver(object):
    method __init__ (line 13) | def __init__(self, root='data/visa'):
    method run (line 19) | def run(self):

FILE: dataset.py
  class VisaDataset (line 10) | class VisaDataset(data.Dataset):
    method __init__ (line 11) | def __init__(self, root, transform, target_transform, mode='test', k_s...
    method __len__ (line 38) | def __len__(self):
    method get_cls_names (line 41) | def get_cls_names(self):
    method __getitem__ (line 44) | def __getitem__(self, index):
  class MVTecDataset (line 63) | class MVTecDataset(data.Dataset):
    method __init__ (line 64) | def __init__(self, root, transform, target_transform, aug_rate, mode='...
    method __len__ (line 92) | def __len__(self):
    method get_cls_names (line 95) | def get_cls_names(self):
    method combine_img (line 98) | def combine_img(self, cls_name):
    method __getitem__ (line 138) | def __getitem__(self, index):

FILE: few_shot.py
  function memory (line 4) | def memory(model_name, model, obj_list, dataset_dir, save_path, preproce...

FILE: loss.py
  class FocalLoss (line 7) | class FocalLoss(nn.Module):
    method __init__ (line 21) | def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_ind...
    method forward (line 34) | def forward(self, logit, target):
  class BinaryDiceLoss (line 89) | class BinaryDiceLoss(nn.Module):
    method __init__ (line 90) | def __init__(self):
    method forward (line 93) | def forward(self, input, targets):

FILE: model.py
  class LinearLayer (line 5) | class LinearLayer(nn.Module):
    method __init__ (line 6) | def __init__(self, dim_in, dim_out, k, model):
    method forward (line 13) | def forward(self, tokens):

FILE: open_clip/coca_model.py
  class MultimodalCfg (line 45) | class MultimodalCfg(CLIPTextCfg):
  function _build_text_decoder_tower (line 53) | def _build_text_decoder_tower(
  class CoCa (line 79) | class CoCa(nn.Module):
    method __init__ (line 80) | def __init__(
    method set_grad_checkpointing (line 126) | def set_grad_checkpointing(self, enable=True):
    method _encode_image (line 135) | def _encode_image(self, images, out_layers, normalize=True):
    method _encode_text (line 139) | def _encode_text(self, text, normalize=True, embed_cls=True):
    method encode_image (line 148) | def encode_image(self, images, out_layers, normalize=True):
    method encode_text (line 152) | def encode_text(self, text, normalize=True, embed_cls=True):
    method forward (line 156) | def forward(self, image, text, embed_cls=True, image_latent=None, imag...
    method generate (line 173) | def generate(
    method _generate_beamsearch (line 296) | def _generate_beamsearch(
  function prepare_inputs_for_generation (line 445) | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **...

FILE: open_clip/factory.py
  function _natural_key (line 29) | def _natural_key(string_):
  function _rescan_model_configs (line 33) | def _rescan_model_configs():
  function list_models (line 57) | def list_models():
  function add_model_config (line 62) | def add_model_config(path):
  function get_model_config (line 70) | def get_model_config(model_name):
  function get_tokenizer (line 77) | def get_tokenizer(model_name):
  function load_state_dict (line 87) | def load_state_dict(checkpoint_path: str, map_location='cpu'):
  function load_checkpoint (line 98) | def load_checkpoint(model, checkpoint_path, strict=True):
  function create_model (line 108) | def create_model(
  function create_loss (line 286) | def create_loss(args):
  function create_model_and_transforms (line 317) | def create_model_and_transforms(
  function create_model_from_pretrained (line 372) | def create_model_from_pretrained(

FILE: open_clip/hf_model.py
  class BaseModelOutput (line 21) | class BaseModelOutput:
  class PretrainedConfig (line 25) | class PretrainedConfig:
  function _camel2snake (line 32) | def _camel2snake(s):
  function register_pooler (line 40) | def register_pooler(cls):
  class MeanPooler (line 47) | class MeanPooler(nn.Module):
    method forward (line 50) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class MaxPooler (line 56) | class MaxPooler(nn.Module):
    method forward (line 59) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class ClsPooler (line 65) | class ClsPooler(nn.Module):
    method __init__ (line 68) | def __init__(self, use_pooler_output=True):
    method forward (line 73) | def forward(self, x: BaseModelOutput, attention_mask: TensorType):
  class HFTextEncoder (line 83) | class HFTextEncoder(nn.Module):
    method __init__ (line 87) | def __init__(
    method forward (line 137) | def forward(self, x: TensorType):
    method lock (line 154) | def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
    method set_grad_checkpointing (line 172) | def set_grad_checkpointing(self, enable=True):
    method init_parameters (line 175) | def init_parameters(self):

FILE: open_clip/loss.py
  function gather_features (line 19) | def gather_features(
  class ClipLoss (line 66) | class ClipLoss(nn.Module):
    method __init__ (line 68) | def __init__(
    method get_ground_truth (line 89) | def get_ground_truth(self, device, num_logits) -> torch.Tensor:
    method get_logits (line 102) | def get_logits(self, image_features, text_features, logit_scale):
    method forward (line 120) | def forward(self, image_features, text_features, logit_scale, output_d...
  class CoCaLoss (line 134) | class CoCaLoss(ClipLoss):
    method __init__ (line 135) | def __init__(
    method forward (line 160) | def forward(self, image_features, text_features, logits, labels, logit...
  class DistillClipLoss (line 176) | class DistillClipLoss(ClipLoss):
    method dist_loss (line 178) | def dist_loss(self, teacher_logits, student_logits):
    method forward (line 181) | def forward(

FILE: open_clip/model.py
  class CLIPVisionCfg (line 24) | class CLIPVisionCfg:
  class CLIPTextCfg (line 49) | class CLIPTextCfg:
  function get_cast_dtype (line 66) | def get_cast_dtype(precision: str):
  function _build_vision_tower (line 75) | def _build_vision_tower(
  function _build_text_tower (line 137) | def _build_text_tower(
  class CLIP (line 176) | class CLIP(nn.Module):
    method __init__ (line 179) | def __init__(
    method lock_image_tower (line 203) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 208) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 212) | def encode_image(self, image, out_layers, normalize: bool = False):
    method encode_text (line 216) | def encode_text(self, text, normalize: bool = False):
    method forward (line 230) | def forward(self, image, text):
  class CustomTextCLIP (line 242) | class CustomTextCLIP(nn.Module):
    method __init__ (line 245) | def __init__(
    method lock_image_tower (line 260) | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
    method lock_text_tower (line 264) | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm:...
    method set_grad_checkpointing (line 268) | def set_grad_checkpointing(self, enable=True):
    method encode_image (line 272) | def encode_image(self, image, normalize: bool = False):
    method encode_text (line 276) | def encode_text(self, text, normalize: bool = False):
    method forward (line 280) | def forward(self, image, text):
  function convert_weights_to_lp (line 292) | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
  function convert_to_custom_text_state_dict (line 320) | def convert_to_custom_text_state_dict(state_dict: dict):
  function build_model_from_openai_state_dict (line 338) | def build_model_from_openai_state_dict(
  function trace_model (line 398) | def trace_model(model, batch_size=256, device=torch.device('cpu')):
  function resize_pos_embed (line 414) | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', ...

FILE: open_clip/modified_resnet.py
  class Bottleneck (line 10) | class Bottleneck(nn.Module):
    method __init__ (line 13) | def __init__(self, inplanes, planes, stride=1):
    method forward (line 42) | def forward(self, x: torch.Tensor):
  class AttentionPool2d (line 58) | class AttentionPool2d(nn.Module):
    method __init__ (line 59) | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, o...
    method forward (line 68) | def forward(self, x):
  class ModifiedResNet (line 95) | class ModifiedResNet(nn.Module):
    method __init__ (line 103) | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
    method _make_layer (line 132) | def _make_layer(self, planes, blocks, stride=1):
    method init_parameters (line 141) | def init_parameters(self):
    method lock (line 154) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 162) | def set_grad_checkpointing(self, enable=True):
    method stem (line 166) | def stem(self, x):
    method forward (line 173) | def forward(self, x, out_blocks):

FILE: open_clip/openai.py
  function list_openai_models (line 18) | def list_openai_models() -> List[str]:
  function load_openai_model (line 23) | def load_openai_model(

FILE: open_clip/pretrained.py
  function _pcfg (line 21) | def _pcfg(url='', hf_hub='', mean=None, std=None):
  function _clean_tag (line 235) | def _clean_tag(tag: str):
  function list_pretrained (line 240) | def list_pretrained(as_str: bool = False):
  function list_pretrained_models_by_tag (line 247) | def list_pretrained_models_by_tag(tag: str):
  function list_pretrained_tags_by_model (line 257) | def list_pretrained_tags_by_model(model: str):
  function is_pretrained_cfg (line 265) | def is_pretrained_cfg(model: str, tag: str):
  function get_pretrained_cfg (line 271) | def get_pretrained_cfg(model: str, tag: str):
  function get_pretrained_url (line 278) | def get_pretrained_url(model: str, tag: str):
  function download_pretrained_from_url (line 283) | def download_pretrained_from_url(
  function has_hf_hub (line 329) | def has_hf_hub(necessary=False):
  function download_pretrained_from_hf (line 337) | def download_pretrained_from_hf(
  function download_pretrained (line 348) | def download_pretrained(

FILE: open_clip/push_to_hf_hub.py
  function save_config_for_hf (line 27) | def save_config_for_hf(
  function save_for_hf (line 45) | def save_for_hf(
  function push_to_hf_hub (line 65) | def push_to_hf_hub(
  function push_pretrained_to_hf_hub (line 124) | def push_pretrained_to_hf_hub(
  function generate_readme (line 163) | def generate_readme(model_card: dict, model_name: str):

FILE: open_clip/timm_model.py
  class TimmModel (line 28) | class TimmModel(nn.Module):
    method __init__ (line 33) | def __init__(
    method lock (line 85) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method set_grad_checkpointing (line 118) | def set_grad_checkpointing(self, enable=True):
    method forward (line 124) | def forward(self, x):

FILE: open_clip/tokenizer.py
  function default_bpe (line 21) | def default_bpe():
  function bytes_to_unicode (line 26) | def bytes_to_unicode():
  function get_pairs (line 48) | def get_pairs(word):
  function basic_clean (line 60) | def basic_clean(text):
  function whitespace_clean (line 66) | def whitespace_clean(text):
  class SimpleTokenizer (line 72) | class SimpleTokenizer(object):
    method __init__ (line 73) | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
    method bpe (line 98) | def bpe(self, token):
    method encode (line 139) | def encode(self, text):
    method decode (line 147) | def decode(self, tokens):
  function decode (line 155) | def decode(output_ids: torch.Tensor):
  function tokenize (line 159) | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> ...
  class HFTokenizer (line 191) | class HFTokenizer:
    method __init__ (line 194) | def __init__(self, tokenizer_name: str):
    method save_pretrained (line 198) | def save_pretrained(self, dest):
    method __call__ (line 201) | def __call__(self, texts: Union[str, List[str]], context_length: int =...

FILE: open_clip/transform.py
  class AugmentationCfg (line 16) | class AugmentationCfg:
  class ResizeMaxSize (line 26) | class ResizeMaxSize(nn.Module):
    method __init__ (line 28) | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, ...
    method forward (line 37) | def forward(self, img):
  function _convert_to_rgb (line 52) | def _convert_to_rgb(image):
  function image_transform (line 56) | def image_transform(

FILE: open_clip/transformer.py
  class LayerNormFp32 (line 14) | class LayerNormFp32(nn.LayerNorm):
    method forward (line 17) | def forward(self, x: torch.Tensor):
  class LayerNorm (line 23) | class LayerNorm(nn.LayerNorm):
    method forward (line 26) | def forward(self, x: torch.Tensor):
  class QuickGELU (line 32) | class QuickGELU(nn.Module):
    method forward (line 34) | def forward(self, x: torch.Tensor):
  class LayerScale (line 38) | class LayerScale(nn.Module):
    method __init__ (line 39) | def __init__(self, dim, init_values=1e-5, inplace=False):
    method forward (line 44) | def forward(self, x):
  class PatchDropout (line 48) | class PatchDropout(nn.Module):
    method __init__ (line 53) | def __init__(self, prob, exclude_first_token=True):
    method forward (line 59) | def forward(self, x):
  class Attention (line 88) | class Attention(nn.Module):
    method __init__ (line 89) | def __init__(
    method forward (line 128) | def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
  class AttentionalPooler (line 164) | class AttentionalPooler(nn.Module):
    method __init__ (line 165) | def __init__(
    method forward (line 179) | def forward(self, x: torch.Tensor):
    method _repeat (line 186) | def _repeat(self, query, N: int):
  class ResidualAttentionBlock (line 190) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 191) | def __init__(
    method attention (line 221) | def attention(
    method forward (line 236) | def forward(
  class CustomResidualAttentionBlock (line 252) | class CustomResidualAttentionBlock(nn.Module):
    method __init__ (line 253) | def __init__(
    method forward (line 287) | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] =...
  class Transformer (line 293) | class Transformer(nn.Module):
    method __init__ (line 294) | def __init__(
    method get_cast_dtype (line 316) | def get_cast_dtype(self) -> torch.dtype:
    method forward (line 319) | def forward(self, x: torch.Tensor, out_layers: list = [3, 6, 9],
  class VisionTransformer (line 342) | class VisionTransformer(nn.Module):
    method __init__ (line 345) | def __init__(
    method lock (line 415) | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
    method init_parameters (line 448) | def init_parameters(self):
    method set_grad_checkpointing (line 469) | def set_grad_checkpointing(self, enable=True):
    method _global_pool (line 472) | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.T...
    method forward (line 478) | def forward(self, x: torch.Tensor, out_layers: list):
  class TextTransformer (line 537) | class TextTransformer(nn.Module):
    method __init__ (line 540) | def __init__(
    method init_parameters (line 588) | def init_parameters(self):
    method set_grad_checkpointing (line 607) | def set_grad_checkpointing(self, enable=True):
    method build_attention_mask (line 610) | def build_attention_mask(self):
    method build_cls_mask (line 618) | def build_cls_mask(self, text, cast_dtype: torch.dtype):
    method _repeat (line 627) | def _repeat(self, t, N: int):
    method forward (line 630) | def forward(self, text):
  class MultimodalTransformer (line 665) | class MultimodalTransformer(Transformer):
    method __init__ (line 666) | def __init__(
    method init_parameters (line 707) | def init_parameters(self):
    method build_attention_mask (line 725) | def build_attention_mask(self):
    method forward (line 733) | def forward(self, image_embs, text_embs):
    method set_grad_checkpointing (line 756) | def set_grad_checkpointing(self, enable=True):

FILE: open_clip/utils.py
  function freeze_batch_norm_2d (line 8) | def freeze_batch_norm_2d(module, module_match={}, name=''):
  function _ntuple (line 48) | def _ntuple(n):

FILE: prompt_ensemble.py
  function encode_text_with_prompt_ensemble (line 8) | def encode_text_with_prompt_ensemble(model, objs, tokenizer, device):

FILE: test.py
  function setup_seed (line 23) | def setup_seed(seed):
  function normalize (line 32) | def normalize(pred, max_value=None, min_value=None):
  function apply_ad_scoremap (line 39) | def apply_ad_scoremap(image, scoremap, alpha=0.5):
  function cal_pro_score (line 47) | def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
  function test (line 74) | def test(args):

FILE: train.py
  function setup_seed (line 22) | def setup_seed(seed):
  function train (line 31) | def train(args):
  },
  {
    "path": "open_clip/model_configs/ViT-M-32.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 512,\n     "
  },
  {
    "path": "open_clip/model_configs/ViT-S-16-alt.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "open_clip/model_configs/ViT-S-16.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "open_clip/model_configs/ViT-S-32-alt.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 256,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "open_clip/model_configs/ViT-S-32.json",
    "chars": 294,
    "preview": "{\n    \"embed_dim\": 384,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 384,\n     "
  },
  {
    "path": "open_clip/model_configs/ViT-bigG-14.json",
    "chars": 354,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 48,\n        \"width\": 1664,\n   "
  },
  {
    "path": "open_clip/model_configs/ViT-e-14.json",
    "chars": 354,
    "preview": "{\n    \"embed_dim\": 1280,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 56,\n        \"width\": 1792,\n   "
  },
  {
    "path": "open_clip/model_configs/ViT-g-14.json",
    "chars": 353,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 40,\n        \"width\": 1408,\n   "
  },
  {
    "path": "open_clip/model_configs/coca_ViT-B-32.json",
    "chars": 659,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "open_clip/model_configs/coca_ViT-L-14.json",
    "chars": 664,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 24,\n        \"width\": 1024,\n    "
  },
  {
    "path": "open_clip/model_configs/coca_base.json",
    "chars": 669,
    "preview": "{\n    \"embed_dim\": 512,\n    \"multimodal_cfg\": {\n        \"width\": 768,\n        \"context_length\": 76,\n        \"vocab_size\""
  },
  {
    "path": "open_clip/model_configs/coca_roberta-ViT-B-32.json",
    "chars": 517,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "open_clip/model_configs/convnext_base.json",
    "chars": 421,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "open_clip/model_configs/convnext_base_w.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "open_clip/model_configs/convnext_base_w_320.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_base\",\n        \"timm_model_pretrained\":"
  },
  {
    "path": "open_clip/model_configs/convnext_large.json",
    "chars": 423,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "open_clip/model_configs/convnext_large_d.json",
    "chars": 420,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "open_clip/model_configs/convnext_large_d_320.json",
    "chars": 420,
    "preview": "{\n    \"embed_dim\": 768,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_large\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "open_clip/model_configs/convnext_small.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_small\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "open_clip/model_configs/convnext_tiny.json",
    "chars": 422,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_tiny\",\n        \"timm_model_pretrained\""
  },
  {
    "path": "open_clip/model_configs/convnext_xlarge.json",
    "chars": 426,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xlarge\",\n        \"timm_model_pretraine"
  },
  {
    "path": "open_clip/model_configs/convnext_xxlarge.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrain"
  },
  {
    "path": "open_clip/model_configs/convnext_xxlarge_320.json",
    "chars": 427,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"convnext_xxlarge\",\n        \"timm_model_pretrain"
  },
  {
    "path": "open_clip/model_configs/mt5-base-ViT-B-32.json",
    "chars": 325,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "open_clip/model_configs/mt5-xl-ViT-H-14.json",
    "chars": 349,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "open_clip/model_configs/roberta-ViT-B-32.json",
    "chars": 343,
    "preview": "{\n    \"embed_dim\": 512,\n    \"quick_gelu\": true,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n   "
  },
  {
    "path": "open_clip/model_configs/swin_base_patch4_window7_224.json",
    "chars": 380,
    "preview": "{\n    \"embed_dim\": 640,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"swin_base_patch4_window7_224\",\n        \"timm_mod"
  },
  {
    "path": "open_clip/model_configs/vit_medium_patch16_gap_256.json",
    "chars": 377,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_medium_patch16_gap_256\",\n        \"timm_model"
  },
  {
    "path": "open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json",
    "chars": 384,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"timm_model_name\": \"vit_relpos_medium_patch16_cls_224\",\n        \"tim"
  },
  {
    "path": "open_clip/model_configs/xlm-roberta-base-ViT-B-32.json",
    "chars": 327,
    "preview": "{\n    \"embed_dim\": 512,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 12,\n        \"width\": 768,\n     "
  },
  {
    "path": "open_clip/model_configs/xlm-roberta-large-ViT-H-14.json",
    "chars": 357,
    "preview": "{\n    \"embed_dim\": 1024,\n    \"vision_cfg\": {\n        \"image_size\": 224,\n        \"layers\": 32,\n        \"width\": 1280,\n   "
  },
  {
    "path": "open_clip/modified_resnet.py",
    "chars": 7216,
    "preview": "from collections import OrderedDict\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom open_c"
  },
  {
    "path": "open_clip/openai.py",
    "chars": 5446,
    "preview": "\"\"\" OpenAI pretrained model functions\n\nAdapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c"
  },
  {
    "path": "open_clip/pretrained.py",
    "chars": 14144,
    "preview": "import hashlib\nimport os\nimport urllib\nimport warnings\nfrom functools import partial\nfrom typing import Dict, Union\n\nfro"
  },
  {
    "path": "open_clip/push_to_hf_hub.py",
    "chars": 7660,
    "preview": "import argparse\nimport json\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import Optional"
  },
  {
    "path": "open_clip/timm_model.py",
    "chars": 5077,
    "preview": "\"\"\" timm model adapter\n\nWraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower "
  },
  {
    "path": "open_clip/tokenizer.py",
    "chars": 7407,
    "preview": "\"\"\" CLIP tokenizer\n\nCopied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.\n\"\"\"\ni"
  },
  {
    "path": "open_clip/transform.py",
    "chars": 4807,
    "preview": "import warnings\nfrom dataclasses import dataclass, asdict\nfrom typing import Any, Dict, Optional, Sequence, Tuple, Union"
  },
  {
    "path": "open_clip/transformer.py",
    "chars": 28469,
    "preview": "from collections import OrderedDict\nimport math\nfrom typing import Callable, Optional, Sequence, Tuple\n\nimport torch\nfro"
  },
  {
    "path": "open_clip/utils.py",
    "chars": 2223,
    "preview": "from itertools import repeat\nimport collections.abc\n\nfrom torch import nn as nn\nfrom torchvision.ops.misc import FrozenB"
  },
  {
    "path": "open_clip/version.py",
    "chars": 23,
    "preview": "__version__ = '2.16.0'\n"
  },
  {
    "path": "prompt_ensemble.py",
    "chars": 2425,
    "preview": "import os\r\nfrom typing import Union, List\r\nfrom pkg_resources import packaging\r\nimport torch\r\nimport numpy as np\r\n\r\n\r\nde"
  },
  {
    "path": "requirements.txt",
    "chars": 298,
    "preview": "ftfy==6.1.1\nhorovod==0.28.1\nhuggingface_hub==0.13.4\nnumpy==1.21.6\nopencv_python==4.6.0.66\npandas==1.3.5\nPillow==9.2.0\nre"
  },
  {
    "path": "test.py",
    "chars": 14417,
    "preview": "import os\r\nimport cv2\r\nimport json\r\nimport torch\r\nimport random\r\nimport logging\r\nimport argparse\r\nimport numpy as np\r\nfr"
  },
  {
    "path": "test_few_shot.sh",
    "chars": 824,
    "preview": "### test on the VisA dataset\npython test.py --mode few_shot --dataset visa \\\n--data_path ./data/visa --save_path ./resul"
  },
  {
    "path": "test_zero_shot.sh",
    "chars": 695,
    "preview": "### test on the VisA dataset\npython test.py --mode zero_shot --dataset visa \\\n--data_path ./data/visa --save_path ./resu"
  },
  {
    "path": "train.py",
    "chars": 7380,
    "preview": "import torch\r\nimport torch.nn as nn\r\nimport numpy as np\r\nimport random\r\nimport os\r\nimport json\r\nimport argparse\r\nfrom to"
  },
  {
    "path": "train.sh",
    "chars": 698,
    "preview": "### train on the MVTec AD dataset\npython train.py --dataset mvtec --train_data_path ./data/mvtec \\\n--save_path ./exps/vi"
  }
]

// ... and 2 more files (download for full content)
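Each entry in the manifest above is a JSON object with three fields: "path" (the file's location in the repository), "chars" (its size in characters), and "preview" (the opening of the file). Once the trailing "// ..." truncation note is stripped, the array is plain JSON and can be loaded directly. A minimal sketch, assuming the array has been saved as manifest.json (a hypothetical filename, not something GitExtract produces under that name):

    import json

    # "manifest.json" is a hypothetical filename for the JSON array above,
    # saved with the trailing "// ..." truncation note removed (that note
    # is a comment, which plain JSON does not allow).
    with open("manifest.json", encoding="utf-8") as f:
        entries = json.load(f)

    # Every entry carries "path", "chars", and "preview" fields.
    total = sum(e["chars"] for e in entries)
    print(f"{len(entries)} files listed, {total:,} characters in total")

    # The ten largest files by character count; in this repository the
    # large source files (open_clip/transformer.py, open_clip/model.py,
    # test.py) dwarf the small model_configs/*.json entries.
    for e in sorted(entries, key=lambda e: e["chars"], reverse=True)[:10]:
        print(f'{e["chars"]:>7}  {e["path"]}')

Nothing here is specific to GitExtract beyond the three field names, which are visible in the listing itself.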

About this extraction

This page contains the full source code of the ByChelsea/VAND-APRIL-GAN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 90 files (24.2 MB), approximately 58.7k tokens, and a symbol index with 227 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
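In the downloadable .txt, each file's contents sit under a "FILE: <path>" banner framed by lines of equals signs. This banner format is inferred from the extraction's own layout, not from any documented GitExtract contract, so treat the following as a sketch: it splits the dump back into individual files on disk, assuming the download was saved as repo.txt (a hypothetical filename).

    import re
    from pathlib import Path

    # "repo.txt" is a hypothetical name for the downloaded full extraction.
    text = Path("repo.txt").read_text(encoding="utf-8")

    # Files are delimited by banners of the form:
    #   ================================================
    #   FILE: some/path.py
    #   ================================================
    # (format inferred from this extraction's layout).
    pattern = re.compile(r"^=+\nFILE: (.+?)\n=+\n", re.MULTILINE)

    # With one capturing group, re.split alternates:
    # [preamble, path1, body1, path2, body2, ...]
    parts = pattern.split(text)
    for path, body in zip(parts[1::2], parts[2::2]):
        out = Path("restored") / path
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(body, encoding="utf-8")
        print(f"wrote {out}")

The same pattern works for pulling out a single file of interest instead of writing the whole tree to disk.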

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
